From f57ae32696f410c046209236646846b4b87cc2bb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 001/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- docs/res/code/plugin_convert.rst | 5 - napari_cellseg3d/_tests/test_helper.py | 2 +- napari_cellseg3d/_tests/test_interface.py | 3 +- .../_tests/test_plugin_inference.py | 1 + napari_cellseg3d/_tests/test_training.py | 2 +- .../code_models/model_instance_seg.py | 274 +++++++++++++++++- napari_cellseg3d/code_models/model_workers.py | 32 +- .../code_models/models/model_test.py | 4 +- .../code_plugins/plugin_convert.py | 158 +--------- .../code_plugins/plugin_model_inference.py | 25 +- .../code_plugins/plugin_model_training.py | 1 - napari_cellseg3d/config.py | 25 +- napari_cellseg3d/interface.py | 21 +- requirements.txt | 1 + 14 files changed, 308 insertions(+), 246 deletions(-) diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 7a244040..43b8c7be 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -19,11 +19,6 @@ ToSemanticUtils .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToSemanticUtils :members: __init__ -InstanceWidgets -********************************** -.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::InstanceWidgets - :members: __init__, run_method - ToInstanceUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToInstanceUtils diff --git a/napari_cellseg3d/_tests/test_helper.py b/napari_cellseg3d/_tests/test_helper.py index 1f740fe0..b35fc111 100644 --- a/napari_cellseg3d/_tests/test_helper.py +++ b/napari_cellseg3d/_tests/test_helper.py @@ -13,4 +13,4 @@ def test_helper(make_napari_viewer): widget.btnc.click() - assert len(viewer.window._dock_widgets) == children-1 + assert len(viewer.window._dock_widgets) == children - 1 diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index b5b09238..be811721 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,5 +1,6 @@ from napari_cellseg3d.interface import Log + def test_log(qtbot): log = Log() log.print_and_log("test") @@ -10,4 +11,4 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" - qtbot.add_widget(log) \ No newline at end of file + qtbot.add_widget(log) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 33eb569b..1ec7e77e 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -6,6 +6,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.code_models.models.model_test import TestModel + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 878e72d6..70b79b31 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -34,7 +34,7 @@ def test_training(make_napari_viewer, qtbot): ################# MODEL_LIST["test"] = TestModel() widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys())-1) + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) # widget.start() # assert widget.worker is not None diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 88940d7d..8b39ad5f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -5,23 +5,63 @@ from typing import List import numpy as np -from skimage.filters import thresholding - -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area +import pyclesperanto_prototype as cle +from qtpy.QtWidgets import QWidget from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed +from skimage.filters import thresholding from skimage.transform import resize + +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread +from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -52,18 +92,43 @@ def get_dict(self): def threshold(volume, thresh): + """Remove all values smaller than the specified threshold in the volume""" im = np.squeeze(volume) binary = im > thresh return np.where(binary, im, np.zeros_like(im)) +def voronoi_otsu( + volume: np.ndarray, + spot_sigma: float, + outline_sigma: float, + remove_small_size: float, +): + """ + Voronoi-Otsu labeling from pyclesperanto. + BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase + https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: + volume (np.ndarray): volume to segment + spot_sigma (float): parameter determining how close detected objects can be + outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels + + Returns: + Instance segmentation labels from Voronoi-Otsu method + """ + semantic = np.squeeze(volume) + instance = cle.voronoi_otsu_labeling( + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + ) + # instance = remove_small_objects(instance, remove_small_size) + return instance + + def binary_connected( volume, thres=0.5, thres_small=3, - # scale_factors=(1.0, 1.0, 1.0), - *args, - **kwargs ): r"""Convert binary foreground probability maps to instance masks via connected-component labeling. @@ -99,12 +164,9 @@ def binary_connected( def binary_watershed( volume, thres_objects=0.3, - thres_small=10, thres_seeding=0.9, - # scale_factors=(1.0, 1.0, 1.0), + thres_small=10, rem_seed_thres=3, - *args, - **kwargs ): r"""Convert binary foreground probability maps to instance masks via watershed segmentation algorithm. @@ -115,10 +177,9 @@ def binary_watershed( Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. - thres_seeding (float): threshold for seeding. Default: 0.98 thres_objects (float): threshold for foreground objects. Default: 0.3 + thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ semantic = np.squeeze(volume) @@ -195,7 +256,7 @@ def to_instance(image, is_file_path=False): result = binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # TODO add params + ) # FIXME add params from utils plugin return result @@ -285,3 +346,188 @@ def fill(lst, n=len(properties) - 1): ratio, fill([len(properties)]), ) + + +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Watershed", + function=binary_watershed, + num_sliders=2, + num_counters=2, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(50) + + self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].tooltips = "Probability threshold for seeding" + self.sliders[1].setValue(90) + + self.counters[0].label.setText("Small object removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(30) + + self.counters[1].label.setText("Small seed removal") + self.counters[1].tooltips = ( + "Volume/size threshold for small seeds removal." + "\nAll seeds with a volume/size below this value will be removed." + ) + self.counters[1].setValue(3) + + def run_method(self, image): + return self.function( + image, + self.sliders[0].value(), + self.sliders[1].value(), + self.counters[0].value(), + self.counters[1].value(), + ) + + +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Connected Components", + function=binary_connected, + num_sliders=1, + num_counters=1, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(80) + + self.counters[0].label.setText("Small objects removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(3) + + def run_method(self, image): + return self.function( + image, self.sliders[0].value(), self.counters[0].value() + ) + + +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Voronoi-Otsu", + function=voronoi_otsu, + num_sliders=0, + num_counters=3, + ) + self.counters[0].label.setText("Spot sigma") + self.counters[ + 0 + ].tooltips = "Determines how close detected objects can be" + self.counters[0].setMaximum(100) + self.counters[0].setValue(2) + + self.counters[1].label.setText("Outline sigma") + self.counters[ + 1 + ].tooltips = "Determines the smoothness of the segmentation" + self.counters[1].setMaximum(100) + self.counters[1].setValue(2) + + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + + def run_method(self, image): + return self.function( + image, + self.counters[0].value(), + self.counters[1].value(), + self.counters[2].value(), + ) + + +class InstanceWidgets(QWidget): + """ + Base widget with several sliders, for use in instance segmentation parameters + """ + + def __init__(self, parent=None): + """ + Creates an InstanceWidgets widget + + Args: + parent: parent widget + """ + super().__init__(parent) + + self.method_choice = ui.DropdownMenu( + INSTANCE_SEGMENTATION_METHOD_LIST.keys() + ) + self.methods = [] + self.instance_widgets = {} + + self.method_choice.currentTextChanged.connect(self._set_visibility) + self._build() + + def _build(self): + + group = ui.GroupedWidget("Instance segmentation") + group.layout.addWidget(self.method_choice) + + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + + self.setLayout(group.layout) + self._set_visibility() + + def _set_visibility(self): + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: + widget.set_visibility(False) + + def run_method(self, volume): + """ + Calls instance function with chosen parameters + Args: + volume: image data to run method on + + Returns: processed image from self._method + """ + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + return method.run_method(volume) + + +INSTANCE_SEGMENTATION_METHOD_LIST = { + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, +} diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 0446936b..50116748 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -51,8 +51,6 @@ from napari_cellseg3d import utils # local -from napari_cellseg3d.code_models.model_instance_seg import binary_connected -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed from napari_cellseg3d.code_models.model_instance_seg import ImageStats from napari_cellseg3d.code_models.model_instance_seg import volume_stats @@ -455,7 +453,9 @@ def model_output( inputs = inputs.to("cpu") model_output = lambda inputs: post_process_transforms( - self.config.model_info.get_model().get_output(model, inputs) # TODO(cyril) refactor those functions + self.config.model_info.get_model().get_output( + model, inputs + ) # TODO(cyril) refactor those functions ) def model_output(inputs): @@ -609,30 +609,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - threshold = ( - self.config.post_process_config.instance.threshold.threshold_value - ) - size_small = ( - self.config.post_process_config.instance.small_object_removal_threshold.threshold_value - ) - method_name = self.config.post_process_config.instance.method - - if method_name == "Watershed": # FIXME use dict in config instead - - def method(image): - return binary_watershed(image, threshold, size_small) - - elif method_name == "Connected components": - - def method(image): - return binary_connected(image, threshold, size_small) - - else: - raise NotImplementedError( - "Selected instance segmentation method is not defined" - ) - - instance_labels = method(to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 09cdd895..5871c4a7 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -7,7 +7,6 @@ def get_weights_file(): class TestModel(nn.Module): - def __init__(self): super().__init__() self.linear = nn.Linear(1, 1) @@ -24,6 +23,7 @@ def get_output(self, _, input): def get_validation(self, val_inputs): return val_inputs + # if __name__ == "__main__": # # model = TestModel() @@ -33,4 +33,4 @@ def get_validation(self, val_inputs): # torch.save( # model.state_dict(), # WEIGHTS_DIR + f"/{get_weights_file()}" -# ) \ No newline at end of file +# ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 37be03c8..08d209b9 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -14,6 +14,7 @@ from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -361,163 +362,6 @@ def _start(self): ) -class InstanceWidgets(QWidget): - """ - Base widget with several sliders, for use in instance segmentation parameters - """ - - def __init__(self, parent=None): - """ - Creates an InstanceWidgets widget - - Args: - parent: parent widget - """ - super().__init__(parent) - - self.method_choice = ui.DropdownMenu( - config.INSTANCE_SEGMENTATION_METHOD_LIST.keys() - ) - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ] - - self.method_choice.currentTextChanged.connect(self._show_connected) - self.method_choice.currentTextChanged.connect(self._show_watershed) - - self.threshold_slider1 = ui.Slider( - lower=0, - upper=100, - default=50, - divide_factor=100.0, - step=5, - text_label="Probability threshold :", - ) - """Base prob. threshold""" - self.threshold_slider2 = ui.Slider( - lower=0, - upper=100, - default=90, - divide_factor=100.0, - step=5, - text_label="Probability threshold (seeding) :", - ) - """Second prob. thresh. (seeding)""" - - self.counter1 = ui.IntIncrementCounter( - upper=100, - default=10, - step=5, - label="Small object removal (pxs) :", - ) - """Small obj. rem.""" - - self.counter2 = ui.IntIncrementCounter( - upper=100, - default=3, - step=5, - label="Small seed removal (pxs) :", - ) - """Small seed rem.""" - - self._build() - - def run_method(self, volume): - """ - Calls instance function with chosen parameters - Args: - volume: image data to run method on - - Returns: processed image from self._method - """ - return self._method( - volume, - self.threshold_slider1.slider_value, - self.counter1.value(), - self.threshold_slider2.slider_value, - self.counter2.value(), - ) - - def _build(self): - - group = ui.GroupedWidget("Instance segmentation") - - ui.add_widgets( - group.layout, - [ - self.method_choice, - self.threshold_slider1.container, - self.threshold_slider2.container, - self.counter1.label, - self.counter1, - self.counter2.label, - self.counter2, - ], - ) - - self.setLayout(group.layout) - self._set_tooltips() - - def _set_tooltips(self): - - self.method_choice.setToolTip( - "Choose which method to use for instance segmentation" - "\nConnected components : all separated objects will be assigned an unique ID. " - "Robust but will not work correctly with adjacent/touching objects\n" - "Watershed : assigns objects ID based on the probability gradient surrounding an object. " - "Requires the model to surround objects in a gradient;" - " can possibly correctly separate unique but touching/adjacent objects." - ) - self.threshold_slider1.tooltips = ( - "All objects below this probability will be ignored (set to 0)" - ) - self.counter1.setToolTip( - "Will remove all objects smaller (in volume) than the specified number of pixels" - ) - self.threshold_slider2.tooltips = ( - "All seeds below this probability will be ignored (set to 0)" - ) - self.counter2.setToolTip( - "Will remove all seeds smaller (in volume) than the specified number of pixels" - ) - - def _show_watershed(self): - name = "Watershed" - if self.method_choice.currentText() == name: - - self._show_slider1() - self._show_slider2() - self._show_counter1() - self._show_counter2() - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_connected(self): - name = "Connected components" - if self.method_choice.currentText() == name: - - self._show_slider1() - self._show_slider2(False) - self._show_counter1() - self._show_counter2(False) - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_slider1(self, is_visible: bool = True): - self.threshold_slider1.container.setVisible(is_visible) - - def _show_slider2(self, is_visible: bool = True): - self.threshold_slider2.container.setVisible(is_visible) - - def _show_counter1(self, is_visible: bool = True): - self.counter1.setVisible(is_visible) - self.counter1.label.setVisible(is_visible) - - def _show_counter2(self, is_visible: bool = True): - self.counter2.setVisible(is_visible) - self.counter2.label.setVisible(is_visible) - - class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 99733936..7810a388 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -12,7 +12,11 @@ from napari_cellseg3d.code_models.model_framework import ModelFramework from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_plugins.plugin_convert import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -77,9 +81,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): config.InferenceWorkerConfig() ) """InferenceWorkerConfig class from config.py""" - self.instance_config: config.InstanceSegConfig = ( - config.InstanceSegConfig() - ) + self.instance_config: InstanceMethod """InstanceSegConfig class from config.py""" self.post_process_config: config.PostProcessConfig = ( config.PostProcessConfig() @@ -551,18 +553,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - instance_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.threshold_slider1.slider_value - ) - instance_small_object_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.counter1.value() - ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.method_choice.currentText(), - threshold=instance_thresh_config, - small_object_removal_threshold=instance_small_object_thresh_config, - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index d77ef0c3..ac8aefc3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -791,7 +791,6 @@ def start(self): self.data = None raise err - model_config = config.ModelInfo( name=self.model_choice.currentText() ) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index eba01b07..f4b11c50 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,9 +8,6 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import binary_connected -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -18,6 +15,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -34,10 +37,6 @@ # "test" : DO NOT USE, reserved for testing } -INSTANCE_SEGMENTATION_METHOD_LIST = { - "Watershed": binary_watershed, - "Connected components": binary_connected, -} WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") @@ -121,21 +120,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: str = None - threshold: Thresholding = Thresholding(enabled=False, threshold_value=0.85) - small_object_removal_threshold: Thresholding = Thresholding( - enabled=True, threshold_value=20 - ) - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 5ba4575d..586a9014 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -10,6 +10,7 @@ from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt + # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor @@ -500,9 +501,12 @@ def __init__( self._build_container() - def _build_container(self): - self.container.layout + def set_visibility(self, visible: bool): + self.container.setVisible(visible) + self.setVisible(visible) + self.text_label.setVisible(visible) + def _build_container(self): if self.text_label is not None: add_widgets( self.container.layout, @@ -1026,7 +1030,7 @@ class DoubleIncrementCounter(QDoubleSpinBox): def __init__( self, lower: Optional[float] = 0.0, - upper: Optional[float] = 10.0, + upper: Optional[float] = 1000.0, default: Optional[float] = 0.0, step: Optional[float] = 1.0, parent: Optional[QWidget] = None, @@ -1049,6 +1053,13 @@ def __init__( if label is not None: self.label = make_label(name=label) + self.valueChanged.connect(self._update_step) + + def _update_step(self): + if self.value() < 0.9: + self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): @@ -1085,6 +1096,10 @@ def make_n( cls, n, lower, upper, default, step, parent, fixed ) + def set_visibility(self, visible: bool): + self.setVisible(visible) + self.label.setVisible(visible) + class IntIncrementCounter(QSpinBox): """Class implementing a number counter with increments (spin box) for int.""" diff --git a/requirements.txt b/requirements.txt index 3ba73405..739b7aa3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 005b71da231c66a8911a213126e939c18dcaac83 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 002/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 8b39ad5f..087b36b5 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -102,7 +102,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -112,11 +112,12 @@ def voronoi_otsu( volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels + Returns: Instance segmentation labels from Voronoi-Otsu method """ + # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma @@ -425,7 +426,7 @@ def __init__(self): name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, ) self.counters[0].label.setText("Spot sigma") self.counters[ @@ -441,18 +442,19 @@ def __init__(self): self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): return self.function( image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) From dd6cf0e50b90f417a8e85b053e05115c1d53ef32 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 003/577] Added new docs for instance seg --- docs/res/code/model_instance_seg.rst | 23 +++++++++++++++++++ .../code_models/model_instance_seg.py | 22 ++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst index e4146ec1..3b323173 100644 --- a/docs/res/code/model_instance_seg.rst +++ b/docs/res/code/model_instance_seg.rst @@ -1,6 +1,29 @@ model_instance_seg.py =========================================== +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu + :members: __init__ + Functions ------------- diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 087b36b5..517643a4 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -34,6 +34,14 @@ def __init__( num_sliders: int, num_counters: int, ): + """ + Methods for instance segmentation + Args: + name: Name of the instance segmentation method (for UI) + function: Function to use for instance segmentation + num_sliders: Number of Slider UI elements needed to set the parameters of the function + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + """ self.name = name self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] @@ -174,7 +182,7 @@ def binary_watershed( Note: This function uses the `skimage.segmentation.watershed `_ - function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays. + function that converts the input image into ``np.float64`` data type for processing. Therefore, please make sure enough memory is allocated when handling large arrays. Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. @@ -350,6 +358,8 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod, metaclass=Singleton): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + def __init__(self): super().__init__( name="Watershed", @@ -393,6 +403,8 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod, metaclass=Singleton): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + def __init__(self): super().__init__( name="Connected Components", @@ -421,6 +433,8 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + def __init__(self): super().__init__( name="Voronoi-Otsu", @@ -428,14 +442,14 @@ def __init__(self): num_sliders=0, num_counters=2, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -529,7 +543,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { + VoronoiOtsu().name: VoronoiOtsu, Watershed().name: Watershed, ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, } From 6b8fd4c6241112149e96937475bb85d1647d315e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 004/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 5 +++-- docs/res/welcome.rst | 15 +++++++++------ napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 933b164e..c14e1a2c 100644 --- a/README.md +++ b/README.md @@ -123,8 +123,9 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). - +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). +Please refer to the documentation for full acknowledgements. ## Plugin base This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 6832e71e..d2f2c0f0 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -90,20 +90,23 @@ We also provide a model that was trained in-house on mesoSPIM nuclei data in col This plugin mainly uses the following libraries and software: -* `napari website`_ +* `napari`_ -* `PyTorch website`_ +* `PyTorch`_ -* `MONAI project website`_ (various models used here are credited `on their website`_) +* `MONAI project`_ (various models used here are credited `on their website`_) + +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ .. _Wyss Center: https://wysscenter.ch/ .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP -.. _napari website: https://napari.org/ -.. _PyTorch website: https://pytorch.org/ -.. _MONAI project website: https://monai.io/ +.. _napari: https://napari.org/ +.. _PyTorch: https://pytorch.org/ +.. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets +.. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype .. rubric:: References diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 586a9014..2e5746ba 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1057,9 +1057,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From d291f6faf5c88050fa463fd5594f6426939ee699 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 005/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 739b7aa3..ead0052c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From aa9ad616a76350979b47a148adaf0893e8b3b3b1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:11:22 +0100 Subject: [PATCH 006/577] Update setup.cfg --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 6261d576..37feca98 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ install_requires = itk tqdm nibabel + pyclesperanto-prototype scikit-image pillow tqdm From 96e5bb9300516b5833c88b97f3b131ba4b492218 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 007/577] isort --- napari_cellseg3d/_tests/fixtures.py | 1 + napari_cellseg3d/_tests/test_plugin_inference.py | 7 ++++--- napari_cellseg3d/_tests/test_training.py | 4 ++-- napari_cellseg3d/_tests/test_weight_download.py | 7 +++---- napari_cellseg3d/code_models/model_instance_seg.py | 4 ++-- napari_cellseg3d/code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 ++++---- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 4 ++-- 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index 8fcf56db..b40a77d3 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,4 +1,5 @@ import warnings + from qtpy.QtWidgets import QTextEdit diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1ec7e77e..5b89c065 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -1,10 +1,11 @@ -from tifffile import imread from pathlib import Path -from napari_cellseg3d.config import MODEL_LIST +from tifffile import imread + from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.config import MODEL_LIST def test_inference(make_napari_viewer, qtbot): diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 70b79b31..21731ba1 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -1,10 +1,10 @@ from pathlib import Path from napari_cellseg3d import config -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.config import MODEL_LIST from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.config import MODEL_LIST def test_training(make_napari_viewer, qtbot): diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 306fdf6c..b8f0d748 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,7 +1,6 @@ -from napari_cellseg3d.code_models.model_workers import ( - WeightsDownloader, - WEIGHTS_DIR, -) +from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import WeightsDownloader + # DISABLED, causes GitHub actions to freeze def test_weight_download(): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 517643a4..f3fb2a2f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -7,11 +7,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget +from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.filters import thresholding from skimage.transform import resize # from skimage.measure import mesh_surface_area @@ -20,8 +20,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 08d209b9..838d505e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -12,9 +12,9 @@ from napari_cellseg3d import config from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 7810a388..3c7ef54f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,13 +10,13 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( INSTANCE_SEGMENTATION_METHOD_LIST, ) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index f4b11c50..65129f64 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,6 +8,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -15,12 +20,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 2e5746ba..e6db6930 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -8,10 +8,10 @@ # Qt from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt # from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From 4e38e634b549752330385804f7431cbbc163c3c4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 008/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 napari_cellseg3d/_tests/pytest.ini diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini new file mode 100644 index 00000000..814cca2e --- /dev/null +++ b/napari_cellseg3d/_tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +qt_api=pyqt5 \ No newline at end of file From 88145e75240bee80ca8eb4953d4c6eee477382f3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 009/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/model_instance_seg.py | 84 +++++++++++-------- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 2 +- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index f3fb2a2f..32aa474b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -20,11 +20,16 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area +# list of methods : +WATERSHED = "Watershed" +CONNECTED_COMP = "Connected Components" +VORONOI_OTSU = "Voronoi-Otsu" + class InstanceMethod: def __init__( @@ -33,6 +38,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, + widget_parent: QWidget = None ): """ Methods for instance segmentation @@ -41,6 +47,7 @@ def __init__( function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets """ self.name = name self.function = function @@ -52,7 +59,7 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -62,7 +69,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label=""), + ui.DoubleIncrementCounter(label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -357,15 +364,16 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -402,15 +410,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -432,15 +441,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -485,7 +495,6 @@ def __init__(self, parent=None): parent: parent widget """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) @@ -496,37 +505,38 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -543,7 +553,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 838d505e..ebcff9d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -384,7 +384,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.data_panel = self._build_io_panel() self.label_layer_loader.set_layer_type(napari.layers.Layer) - self.instance_widgets = InstanceWidgets() + self.instance_widgets = InstanceWidgets(parent=self) self.start_btn = ui.Button("Start", self._start) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 3c7ef54f..604d3e78 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -191,7 +191,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## ################## # instance segmentation widgets - self.instance_widgets = InstanceWidgets(self) + self.instance_widgets = InstanceWidgets(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", func=self._toggle_display_instance From d9f2e0a050842d4eac34a4cf7acd88dacb1cc4c8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 010/577] Fix inference --- .../code_models/model_instance_seg.py | 5 +- napari_cellseg3d/code_models/model_workers.py | 12 ++--- .../code_plugins/plugin_model_inference.py | 13 ++--- napari_cellseg3d/config.py | 6 ++- notebooks/assess_instance.ipynb | 50 +++++++++++++++++++ 5 files changed, 71 insertions(+), 15 deletions(-) create mode 100644 notebooks/assess_instance.ipynb diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 32aa474b..af9d0af8 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -498,8 +498,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -511,6 +513,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 50116748..ad0b447e 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -313,9 +313,7 @@ def log_parameters(self): instance_config = config.post_process_config.instance if instance_config.enabled: self.log( - f"Instance segmentation enabled, method : {instance_config.method}\n" - f"Probability threshold is {instance_config.threshold.threshold_value:.2f}\n" - f"Objects smaller than {instance_config.small_object_removal_threshold.threshold_value} pixels will be removed\n" + f"Instance segmentation enabled, method : {instance_config.method.name}\n" ) self.log("-" * 20) @@ -388,7 +386,7 @@ def load_folder(self): return inference_loader def load_layer(self): - self.log("Loading layer\n") + self.log("\nLoading layer\n") data = np.squeeze(self.config.layer) volume = np.array(data, dtype=np.int16) @@ -553,7 +551,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -609,8 +607,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 604d3e78..2ad7371c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -553,9 +553,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -727,13 +728,13 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method = self.worker_config.post_process_config.instance.method + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 ) # remove background - name = f"({number_cells} objects)_{method}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" instance_layer = viewer.add_labels(labels, name=name) @@ -748,7 +749,7 @@ def on_yield(self, result: InferenceResult): f"Number of instances : {stats.number_objects}" ) - csv_name = f"/{method}_seg_results_{image_id}_{utils.get_date_time()}.csv" + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" stats_df.to_csv( self.worker_config.results_path + csv_name, index=False, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 65129f64..9f94ff1f 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -118,12 +118,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb new file mode 100644 index 00000000..40412282 --- /dev/null +++ b/notebooks/assess_instance.ipynb @@ -0,0 +1,50 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tifffile import imread\n", + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From a25f18f66363d2b039db2f02923f0cf431a7ad83 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 011/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 10 +- napari_cellseg3d/code_plugins/plugin_crop.py | 4 +- .../code_plugins/plugin_utilities.py | 7 +- .../dev_scripts/artefact_labeling.py | 421 ++++++++++++++++++ .../dev_scripts/correct_labels.py | 320 +++++++++++++ .../dev_scripts/evaluate_labels.py | 276 ++++++++++++ notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 7 files changed, 1421 insertions(+), 18 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/artefact_labeling.py create mode 100644 napari_cellseg3d/dev_scripts/correct_labels.py create mode 100644 napari_cellseg3d/dev_scripts/evaluate_labels.py diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index af9d0af8..6c8ee74f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -42,12 +42,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -123,14 +125,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method + """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) @@ -154,6 +157,7 @@ def binary_connected( thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) + """ semantic = np.squeeze(volume) foreground = semantic > thres # int(255 * thres) @@ -197,6 +201,7 @@ def binary_watershed( thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 rem_seed_thres (int): threshold for small seeds removal. Default : 3 + """ semantic = np.squeeze(volume) seed_map = semantic > thres_seeding @@ -493,6 +498,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -544,10 +550,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 23ba190c..153b5e69 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -176,8 +176,8 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 400]) - self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding) + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() # def _check_results_path(self, folder): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 6c726c25..ad1f5547 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -59,10 +59,10 @@ def _build(self): layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) - layout.setSizeConstraint(QLayout.SetFixedSize) + # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) - self.setMinimumHeight(1000) - self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + # self.setMinimumHeight(2000) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._update_visibility() def _create_utils_widgets(self, names): @@ -78,7 +78,6 @@ def _create_utils_widgets(self, names): raise RuntimeError( "One or several utility widgets are missing/erroneous" ) - # TODO how to auto-update list based on UTILITIES_WIDGETS ? def _update_visibility(self): widget_class = UTILITIES_WIDGETS[self.utils_choice.currentText()] diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py new file mode 100644 index 00000000..875ca9b6 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -0,0 +1,421 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path +import scipy.ndimage as ndimage +import os +import napari +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu + +""" +New code by Yves Paychere +Creates labels of artifacts in an image based on existing labels of neurons +""" + + +def map_labels(labels, artefacts): + """Map the artefacts labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + artefacts : ndarray + Label image with artefacts labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the artefact and the label value of the neurone associated or the neurons associated + new_labels: list + The labels of the artefacts that are not labelled in the neurons + """ + map_labels_existing = [] + new_labels = [] + + for i in np.unique(artefacts): + if i == 0: + continue + indexes = labels[artefacts == i] + # find the most common label in the indexes + unique, counts = np.unique(indexes, return_counts=True) + unique = np.flip(unique[np.argsort(counts)]) + counts = np.flip(counts[np.argsort(counts)]) + if unique[0] != 0: + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + elif ( + counts[0] < np.sum(counts) * 2 / 3.0 + ): # the artefact is connected to multiple neurons + total = 0 + ii = 1 + while total < np.size(indexes) / 3.0: + total = np.sum(counts[1 : ii + 1]) + ii += 1 + map_labels_existing.append(np.append([i], unique[1 : ii + 1])) + else: + new_labels.append(i) + + return map_labels_existing, new_labels + + +def make_labels( + path_image, + path_labels_out, + threshold_factor=1, + threshold_size=30, + label_value=1, + do_multi_label=True, + use_watershed=True, + augment_contrast_factor=2, +): + """Detect nucleus. using a binary watershed algorithm and otsu thresholding. + Parameters + ---------- + path_image : str + Path to image. + path_labels_out : str + Path of the output labelled image. + threshold_size : int, optional + Threshold for nucleus size, if the nucleus is smaller than this value it will be removed. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different nucleus will be labelled as a different value. + use_watershed : bool, optional + If True, use watershed algorithm to detect nucleus. + augment_contrast_factor : int, optional + Factor to augment the contrast of the image. + Returns + ------- + ndarray + Label image with nucleus labelled with 1 value per nucleus. + """ + + image = imread(path_image) + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + threshold_brightness = threshold_otsu(image) * threshold_factor + image_contrasted = np.where(image > threshold_brightness, image, 0) + + if use_watershed: + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = image_contrasted * augment_contrast_factor + image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) + labels = binary_watershed(image_contrasted, thres_small=threshold_size) + else: + labels = ndimage.label(image_contrasted)[0] + + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + + if not do_multi_label: + labels = np.where(labels > 0, label_value, 0) + + imwrite(path_labels_out, labels.astype(np.uint16)) + imwrite( + path_labels_out.replace(".tif", "_contrast.tif"), + image_contrasted.astype(np.float32), + ) + + +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): + """Select image by labels. + Parameters + ---------- + path_image : str + Path to image. + path_labels : str + Path to labels. + path_image_out : str + Path of the output image. + label_values : list + List of label values to select. + """ + image = imread(path_image) + labels = imread(path_labels) + image = np.where(np.isin(labels, label_values), image, 0) + imwrite(path_image_out, image.astype(np.float32)) + + +# select the smalles cube that contains all the none zero pixel of an 3d image +def get_bounding_box(img): + height = np.any(img, axis=(0, 1)) + rows = np.any(img, axis=(0, 2)) + cols = np.any(img, axis=(1, 2)) + + xmin, xmax = np.where(cols)[0][[0, -1]] + ymin, ymax = np.where(rows)[0][[0, -1]] + zmin, zmax = np.where(height)[0][[0, -1]] + return xmin, xmax, ymin, ymax, zmin, zmax + + +# crop the image +def crop_image(img): + xmin, xmax, ymin, ymax, zmin, zmax = get_bounding_box(img) + return img[xmin:xmax, ymin:ymax, zmin:zmax] + + +def crop_image_path(path_image, path_image_out): + """Crop image. + Parameters + ---------- + path_image : str + Path to image. + path_image_out : str + Path of the output image. + """ + image = imread(path_image) + image = crop_image(image) + imwrite(path_image_out, image.astype(np.float32)) + + +def make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, + label_value=2, + do_multi_label=False, + remove_true_labels=True, +): + """Detect pseudo nucleus. + Parameters + ---------- + image : ndarray + Image. + labels : ndarray + Label image. + threshold_artefact_brightness_percent : int, optional + Threshold for artefact brightness. + threshold_artefact_size_percent : int, optional + Threshold for artefact size, if the artefcact is smaller than this percentage of the neurons it will be removed. + contrast_power : int, optional + Power for contrast enhancement. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different artefact will be labelled as a different value. + remove_true_labels : bool, optional + If True, the true labels will be removed from the artefacts. + Returns + ------- + ndarray + Label image with pseudo nucleus labelled with 1 value per artefact. + """ + + neurons = np.array(labels > 0) + non_neurons = np.array(labels == 0) + + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + # calculate the percentile of the intensity of all the pixels that are labeled as neurons + # check if the neurons are not empty + if np.sum(neurons) > 0: + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + else: + # take the percentile of the non neurons if the neurons are empty + threshold = np.percentile(image[non_neurons], 90) + + # modify the contrast of the image accoring to the threshold with a tanh function and map the values to [0,1] + + image_contrasted = np.tanh((image - threshold) * contrast_power) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) + + artefacts = binary_watershed( + image_contrasted, thres_seeding=0.95, thres_small=15, thres_objects=0.4 + ) + + if remove_true_labels: + # evaluate where the artefacts are connected to the neurons + # map the artefacts label to the neurons label + map_labels_existing, new_labels = map_labels(labels, artefacts) + + # remove the artefacts that are connected to the neurons + for i in map_labels_existing: + artefacts[artefacts == i[0]] = 0 + # remove all the pixels of the neurons from the artefacts + artefacts = np.where(labels > 0, 0, artefacts) + + # remove the artefacts that are too small + # calculate the percentile of the size of the neurons + if np.sum(neurons) > 0: + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + else: + # find the size of each connected component + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + # remove the smallest connected components + neurone_size_percentile = np.percentile(sizes, 95) + + # select the artefacts that are bigger than the percentile + + artefacts = select_artefacts_by_size( + artefacts, min_size=neurone_size_percentile, is_labeled=True + ) + + # relabel with the label value if the artefacts are not multi label + if not do_multi_label: + artefacts = np.where(artefacts > 0, label_value, artefacts) + + return artefacts + + +def select_artefacts_by_size(artefacts, min_size, is_labeled=False): + """Select artefacts by size. + Parameters + ---------- + artefacts : ndarray + Label image with artefacts labelled as 1. + min_size : int, optional + Minimum size of artefacts to keep + is_labeled : bool, optional + If True, the artefacts are already labelled. + Returns + ------- + ndarray + Label image with artefacts labelled and small artefacts removed. + """ + if not is_labeled: + # find all the connected components in the artefacts image + labels = ndimage.label(artefacts)[0] + else: + labels = artefacts + + # remove the small components + labels_i, counts = np.unique(labels, return_counts=True) + labels_i = labels_i[counts > min_size] + labels_i = labels_i[labels_i > 0] + artefacts = np.where(np.isin(labels, labels_i), labels, 0) + return artefacts + + +def create_artefact_labels( + image_path, + labels_path, + output_path, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. + Parameters + ---------- + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. + output_path : str + Path to save the output label image file. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + image = imread(image_path) + labels = imread(labels_path) + + artefacts = make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power=contrast_power, + label_value=2, + do_multi_label=False, + ) + + neurons_artefacts_labels = np.where(labels > 0, 1, artefacts) + imwrite(output_path, neurons_artefacts_labels) + + +def visualize_images(paths): + """Visualize images. + Parameters + ---------- + paths : list + List of paths to images to visualize. + """ + viewer = napari.Viewer(ndisplay=3) + for path in paths: + viewer.add_image(imread(path), name=os.path.basename(path)) + # wait for the user to close the viewer + napari.run() + + +def create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1 for all images in a folder. The images created are stored in a folder artefact_neurons. + Parameters + ---------- + path : str + Path to folder with images in folder volumes and labels in folder lab_sem. The images are expected to have the same alphabetical order in both folders. + do_visualize : bool, optional + If True, the images will be visualized. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + # find all the images in the folder and create a list + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + # sort the list + path_labels.sort() + path_images.sort() + # create the output folder + os.makedirs(path + "/artefact_neurons", exist_ok=True) + # create the artefact labels + for i in range(len(path_images)): + print(path_labels[i]) + # consider that the images and the labels have names in the same alphabetical order + create_artefact_labels( + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power, + ) + if do_visualize: + visualize_images( + [ + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + ] + ) + + +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py new file mode 100644 index 00000000..f94327e2 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -0,0 +1,320 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings +from napari.qt.threading import thread_worker +from tqdm import tqdm +import threading +# import sys +# sys.path.append(str(Path(__file__) / "../../")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +""" +New code by Yves Paychère +Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold +""" + + +def relabel_non_unique_i(label, save_path, go_fast=False): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + label : np.array + the label image + save_path : str + the path to save the relabeld image + """ + value_label = 0 + new_labels = np.zeros_like(label) + map_labels_existing = [] + unique_label = np.unique(label) + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + i = unique_label[i_label] + if i == 0: + continue + if go_fast: + new_label, to_add = ndimage.label(label == i) + map_labels_existing.append( + [i, list(range(value_label + 1, value_label + to_add + 1))] + ) + + else: + # catch the warning of the watershed + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_label = binary_watershed(label == i) + unique = np.unique(new_label) + to_add = unique[-1] + map_labels_existing.append([i, unique[1:] + value_label]) + + new_label[new_label != 0] += value_label + new_labels += new_label + value_label += to_add + + imwrite(save_path, new_labels) + return map_labels_existing + + +def add_label(old_label, artefact, new_label_path, i_labels_to_add): + """add the label to the label image + Parameters + ---------- + old_label : np.array + the label image + artefact : np.array + the artefact image that contains some neurons + new_label_path : str + the path to save the new label image + """ + new_label = old_label.copy() + max_label = np.max(old_label) + for i, i_label in enumerate(i_labels_to_add): + new_label[artefact == i_label] = i + max_label + 1 + imwrite(new_label_path, new_label) + + +returns = [] + + +def ask_labels(unique_artefact): + global returns + returns = [] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + + if i_labels_to_add_tmp == [0]: + print("no label added") + returns = [[]] + print("close the napari window to continue") + return + + for i in i_labels_to_add_tmp: + if i == 0: + print("0 is not a valid label") + # delete the 0 + i_labels_to_add_tmp.remove(i) + # test if all index are negative + if all(i < 0 for i in i_labels_to_add_tmp): + print( + "all labels are negative-> will add all the labels except the one you gave" + ) + i_labels_to_add = list(unique_artefact) + for i in i_labels_to_add_tmp: + if np.abs(i) in i_labels_to_add: + i_labels_to_add.remove(np.abs(i)) + else: + print("the label", np.abs(i), "is not in the label image") + i_labels_to_add_tmp = i_labels_to_add + else: + # remove the negative index + for i in i_labels_to_add_tmp: + if i < 0: + i_labels_to_add_tmp.remove(i) + print( + "ignore the negative label", + i, + " since not all the labels are negative", + ) + if i not in unique_artefact: + print("the label", i, "is not in the label image") + i_labels_to_add_tmp.remove(i) + + returns = [i_labels_to_add_tmp] + print("close the napari window to continue") + + +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + image_path : str + the path to the image + label_path : str + the path to the label image + go_fast : bool, optional + if True, the relabeling will be faster but the labels can more frequently be merged, by default False + check_for_unicity : bool, optional + if True, the relabeling will check if the labels are unique, by default True + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + global returns + + label = imread(label_path) + initial_label_path = label_path + if check_for_unicity: + # check if the label are unique + new_label_path = label_path[:-4] + "_relabel_unique.tif" + map_labels_existing = relabel_non_unique_i( + label, new_label_path, go_fast=go_fast + ) + print( + "visualize the relabeld image in white the previous labels and in red the new labels" + ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + label_path = new_label_path + # detect artefact + print("detection of potential neurons (in progress)") + image = imread(image_path) + artefact = make_artefact_labels.make_artefact_labels( + image, + imread(label_path), + do_multi_label=True, + threshold_artefact_brightness_percent=30, + threshold_artefact_size_percent=0, + contrast_power=30, + ) + print("detection of potential neurons (done)") + # ask the user if the artefact are not neurons + i_labels_to_add = [] + loop = True + unique_artefact = list(np.unique(artefact)) + while loop: + # visualize the artefact and ask the user which label to add to the label image + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t.start() + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="potential neurons") + viewer.add_labels(imread(label_path), name="labels") + napari.run() + t.join() + i_labels_to_add_tmp = returns[0] + # check if the selected labels are neurones + for i in i_labels_to_add: + if i not in i_labels_to_add_tmp: + i_labels_to_add_tmp.append(i) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + print("these labels will be added") + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if revert != "y": + i_labels_to_add = i_labels_to_add_tmp + for i in i_labels_to_add: + if i in unique_artefact: + unique_artefact.remove(i) + loop = input("Do you want to add more labels? (y/n)") == "y" + # add the label to the label image + new_label_path = initial_label_path[:-4] + "_new_label.tif" + print("the new label will be saved in", new_label_path) + add_label(imread(label_path), artefact, new_label_path, i_labels_to_add) + # store the artefact remaining + new_artefact_path = initial_label_path[:-4] + "_artefact.tif" + artefact = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + imwrite(new_artefact_path, artefact) + + +def modify_viewer(old_label, new_label, args): + """modify the viewer to show the relabeling + Parameters + ---------- + old_label : napari.layers.Labels + the layer of the old label + new_label : napari.layers.Labels + the layer of the new label + args : list + the first element is the old label and the second element is the new label + """ + if args == "hide new label": + new_label.visible = False + elif args == "show new label": + new_label.visible = True + else: + old_label.selected_label = args[0] + if not np.isnan(args[1]): + new_label.selected_label = args[1] + + +@thread_worker +def to_show(map_labels_existing, delay=0.5): + """modify the viewer to show the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the of the map between the old label and the new label + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + time.sleep(2) + for i in map_labels_existing: + yield "hide new label" + if len(i[1]): + yield [i[0], i[1][0]] + else: + yield [i[0], np.nan] + time.sleep(delay) + yield "show new label" + for j in i[1]: + yield [i[0], j] + time.sleep(delay) + + +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): + """Builds a widget that can control a function in another thread.""" + + worker = to_show(map_labels_existing, delay) + worker.start() + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + + +def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): + """visualize the map of the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the relabeling + """ + label = imread(label_path) + relabel = imread(relabel_path) + + viewer = napari.Viewer(ndisplay=3) + + old_label = viewer.add_labels(label, num_colors=3) + new_label = viewer.add_labels(relabel, num_colors=3) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + + # viewer.dims.ndisplay = 3 + viewer.camera.angles = (180, 3, 50) + viewer.camera.zoom = 1 + + old_label.show_selected_label = True + new_label.show_selected_label = True + + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + napari.run() + + +def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + folder_path : str + the path to the folder containing the label images + end_of_new_name : str + thename to add at the end of the relabled image + """ + for file in Path.iterdir(folder_path): + if file.suffix == ".tif": + label = imread(str(Path(folder_path / file))) + relabel_non_unique_i( + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + ) + + +if __name__ == "__main__": + + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py new file mode 100644 index 00000000..857bcd19 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -0,0 +1,276 @@ +import numpy as np +import pandas as pd +from tqdm import tqdm +import napari + +from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels + + +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): + """Evaluate the model performance. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + do_print : bool + If True, print the results. + Returns + ------- + neuron_found : float + The number of neurons found by the model + neuron_fused: float + The number of neurons fused by the model + neuron_not_found: float + The number of neurons not found by the model + neuron_artefact: float + The number of artefact that the model wrongly labelled as neurons + mean_true_positive_ratio_model: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the model's label) + mean_ratio_pixel_found: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_ratio_pixel_found_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_true_positive_ratio_model_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels in any fused neurons of this model's label)/(total number of pixels of the model's label) + mean_ratio_false_pixel_artefact: float + The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) + """ + log.debug("Mapping labels...") + map_labels_existing, map_fused_neurons, new_labels = map_labels( + labels, model_labels + ) + + # calculate the number of neurons individually found + neurons_found = len(map_labels_existing) + # calculate the number of neurons fused + neurons_fused = len(map_fused_neurons) + # calculate the number of neurons not found + log.debug("Calculating the number of neurons not found...") + neurons_found_labels = np.unique( + [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] + ) + unique_labels = np.unique(labels) + neurons_not_found = len(unique_labels) - 1 - len(neurons_found_labels) + # artefacts found + artefacts_found = len(new_labels) + if len(map_labels_existing) > 0: + # calculate the mean true positive ratio of the model + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + # calculate the mean ratio of the neurons pixels correctly labelled + mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) + else: + mean_true_positive_ratio_model = np.nan + mean_ratio_pixel_found = np.nan + + if len(map_fused_neurons) > 0: + # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + # calculate the mean true positive ratio of the model for the fused neurons + mean_true_positive_ratio_model_fused = np.mean( + [i[3] for i in map_fused_neurons] + ) + else: + mean_ratio_pixel_found_fused = np.nan + mean_true_positive_ratio_model_fused = np.nan + + # calculate the mean false positive ratio of each artefact + if len(new_labels) > 0: + mean_ratio_false_pixel_artefact = np.mean([i[1] for i in new_labels]) + else: + mean_ratio_false_pixel_artefact = np.nan + + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( + "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, + ) + print( + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, + ) + print( + "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, + ) + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + ) + if visualize: + viewer = napari.Viewer() + viewer.add_labels(labels, name="ground truth") + viewer.add_labels(model_labels, name="model's labels") + found_model = np.where( + np.isin(model_labels, [i[0] for i in map_labels_existing]), + model_labels, + 0, + ) + viewer.add_labels(found_model, name="model's labels found") + found_label = np.where( + np.isin(labels, [i[1] for i in map_labels_existing]), labels, 0 + ) + viewer.add_labels(found_label, name="ground truth found") + neurones_not_found_labels = np.where( + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + ) + neurones_not_found_labels = neurones_not_found_labels[ + neurones_not_found_labels != 0 + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + viewer.add_labels(not_found, name="ground truth not found") + artefacts_found = np.where( + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + ) + viewer.add_labels(artefacts_found, name="model's labels artefacts") + fused_model = np.where( + np.isin(model_labels, [i[0] for i in map_fused_neurons]), + model_labels, + 0, + ) + viewer.add_labels(fused_model, name="model's labels fused") + fused_label = np.where( + np.isin(labels, [i[1] for i in map_fused_neurons]), labels, 0 + ) + viewer.add_labels(fused_label, name="ground truth fused") + napari.run() + + return ( + neurons_found, + neurons_fused, + neurons_not_found, + artefacts_found, + mean_true_positive_ratio_model, + mean_ratio_pixel_found, + mean_ratio_pixel_found_fused, + mean_true_positive_ratio_model_fused, + mean_ratio_false_pixel_artefact, + ) + + +def save_as_csv(results, path): + """ + Save the results as a csv file + + Parameters + ---------- + results: list + The results of the evaluation + path: str + The path to save the csv file + """ + print(np.array(results).shape) + df = pd.DataFrame( + [results], + columns=[ + "neurons_found", + "neurons_fused", + "neurons_not_found", + "artefacts_found", + "mean_true_positive_ratio_model", + "mean_ratio_pixel_found", + "mean_ratio_pixel_found_fused", + "mean_true_positive_ratio_model_fused", + "mean_ratio_false_pixel_artefact", + ], + ) + df.to_csv(path, index=False) + + +# if __name__ == "__main__": +# """ +# # Example of how to use the functions in this module. +# a = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# b = np.array([[5, 5, 0, 0], [5, 5, 2, 0], [0, 2, 2, 0], [0, 0, 2, 0]]) +# evaluate_model_performance(a, b) +# +# c = np.array([[2, 2, 0, 0], [2, 2, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# d = np.array([[4, 0, 4, 0], [4, 4, 4, 0], [0, 4, 4, 0], [0, 0, 4, 0]]) +# +# evaluate_model_performance(c, d) +# +# from tifffile import imread +# labels=imread("dataset/visual_tif/labels/testing_im_new_label.tif") +# labels_model=imread("dataset/visual_tif/artefact_neurones/basic_model.tif") +# evaluate_model_performance(labels, labels_model,visualize=True) +# """ +# from tifffile import imread +# +# labels = imread("dataset_clean/VALIDATION/validation_labels.tif") +# try: +# labels_model = imread("results/watershed_based_model/instance_labels.tif") +# except: +# raise Exception( +# "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" +# ) +# +# evaluate_model_performance(labels, labels_model, visualize=True) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From f6ce9073660bdf3910e18ef60c5892b42b39f1f2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 012/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From 6812d2c238aa19e96e21cbf522fd68e712958fd9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 013/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- napari_cellseg3d/utils.py | 2 +- notebooks/assess_instance.ipynb | 553 ++++++++---------- requirements.txt | 4 +- setup.cfg | 2 +- 8 files changed, 569 insertions(+), 354 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6c8ee74f..7ca904c1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -141,7 +141,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 4c3ef7d4..1590e22a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -134,7 +134,7 @@ def resize(image, zoom_factors): mode="nearest-exact", padding_mode="empty", )(np.expand_dims(image, axis=0)) - return isotropic_image[0] + return isotropic_image[0].numpy() def align_array_sizes(array_shape, target_shape): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/requirements.txt b/requirements.txt index ead0052c..834a225e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ black coverage isort +itk pytest pytest-qt sphinx @@ -18,6 +19,7 @@ matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,scikit-image,itk,einops]>=0.9.0 +monai[nibabel,einops]>=1.0.1 pillow +scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 37feca98..9cff5fa8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=0.9.0 + monai[nibabel,einops]>=1.0.1 itk tqdm nibabel From e243b27d8d371df1730bf5977c1f1e48cf6ffa9a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 014/577] black --- .../code_models/model_instance_seg.py | 21 ++++++++---- napari_cellseg3d/code_models/model_workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 7ca904c1..c951f176 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -38,7 +38,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -61,7 +61,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -372,13 +379,13 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -418,13 +425,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -455,7 +462,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index ad0b447e..b5ae3920 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -551,7 +551,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 2ad7371c..0e7b05c6 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,7 +555,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -728,7 +730,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 9f94ff1f..f0db27cc 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -118,11 +118,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From cc549df86f1db87de59fd956f5a2b2bbc59296fc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 015/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From d9f423fedd33e0d955bc2e257c06e9f83d8f08f5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 016/577] Added pre-commit hooks --- .pre-commit-config.yaml | 69 +++++++++++++++++++++++------------------ requirements.txt | 2 ++ 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7d452eba..802dfe20 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,39 +1,46 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: check-docstring-first - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.20.0 - hooks: - - id: setup-cfg-fmt - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - additional_dependencies: [flake8-typing-imports>=1.9.0] - - repo: https://github.com/myint/autoflake - rev: v1.4 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v4.0.1 +# hooks: +# - id: check-docstring-first +# - id: end-of-file-fixer +# - id: trailing-whitespace +# - repo: https://github.com/asottile/setup-cfg-fmt +# rev: v1.20.0 +# hooks: +# - id: setup-cfg-fmt +# - repo: https://github.com/PyCQA/flake8 +# rev: 4.0.1 +# hooks: +# - id: flake8 +# additional_dependencies: [flake8-typing-imports>=1.9.0] +# - repo: https://github.com/myint/autoflake +# rev: v1.4 +# hooks: +# - id: autoflake +# args: ["--in-place", "--remove-all-unused-imports"] +# - repo: https://github.com/PyCQA/isort +# rev: 5.10.1 +# hooks: +# - id: isort + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.0.257' + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 21.11b1 + rev: 22.3.0 hooks: - id: black - - repo: https://github.com/asottile/pyupgrade - rev: v2.29.1 - hooks: - - id: pyupgrade - args: [--py38-plus, --keep-runtime-typing] + args: [--line-length=88] +# - repo: https://github.com/asottile/pyupgrade +# rev: v2.29.1 +# hooks: +# - id: pyupgrade +# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks - rev: v0.2.0 + rev: v0.3.0 hooks: - id: napari-plugin-checks # https://mypy.readthedocs.io/en/stable/introduction.html diff --git a/requirements.txt b/requirements.txt index 834a225e..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 818a1a851910a7a81ef63475cc9e1e300120d2f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:40:31 +0100 Subject: [PATCH 017/577] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 802dfe20..d1e22fb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: rev: 22.3.0 hooks: - id: black - args: [--line-length=88] + args: [--line-length=79] # - repo: https://github.com/asottile/pyupgrade # rev: v2.29.1 # hooks: From 7fa0c787c2db26da353a7231b77f5ba670ddbe4b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:48:32 +0100 Subject: [PATCH 018/577] Update pyproject.toml --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cb96e206..1af55e91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,9 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" - +[tool.ruff] +# Never enforce `E501` (line length violations). +ignore = ["E501"] [tool.black] line-length = 79 From 7e428949e8d3d4fed2c1a47969349b0ec74c058e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:50:33 +0100 Subject: [PATCH 019/577] Update pyproject.toml Ruff config --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1af55e91..8d596d54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Never enforce `E501` (line length violations). -ignore = ["E501"] +ignore = ["E501", "E741"] [tool.black] line-length = 79 From 72eb89c006d3046dc716b43ffbec1de6f1e85078 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 020/577] Enfore pre-commit style --- .gitignore | 1 + napari_cellseg3d/_tests/test_dock_widget.py | 2 +- napari_cellseg3d/_tests/test_helper.py | 1 - .../_tests/test_plugin_inference.py | 1 - napari_cellseg3d/_tests/test_review.py | 1 - napari_cellseg3d/_tests/test_utils.py | 5 -- .../code_models/model_framework.py | 1 - .../code_models/model_instance_seg.py | 10 +-- napari_cellseg3d/code_models/model_workers.py | 32 +++----- .../code_models/models/model_TRAILMAP.py | 6 -- .../code_models/models/model_TRAILMAP_MS.py | 1 - napari_cellseg3d/code_plugins/plugin_base.py | 3 +- .../code_plugins/plugin_convert.py | 9 --- napari_cellseg3d/code_plugins/plugin_crop.py | 4 - .../code_plugins/plugin_metrics.py | 2 +- .../code_plugins/plugin_model_inference.py | 14 +--- .../code_plugins/plugin_model_training.py | 6 +- .../code_plugins/plugin_review.py | 5 -- .../code_plugins/plugin_review_dock.py | 3 +- .../code_plugins/plugin_utilities.py | 2 - napari_cellseg3d/config.py | 5 +- .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- napari_cellseg3d/dev_scripts/thread_test.py | 1 - napari_cellseg3d/interface.py | 13 +-- napari_cellseg3d/utils.py | 13 ++- notebooks/assess_instance.ipynb | 79 +++++++++++++------ notebooks/csv_cell_plot.ipynb | 2 - 29 files changed, 102 insertions(+), 145 deletions(-) diff --git a/.gitignore b/.gitignore index ffe6e1f8..20fdcd08 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index f621dba4..7737e540 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -15,7 +15,7 @@ def test_prepare(make_napari_viewer): widget.prepare(path_image, ".tif", "", False) assert widget.filetype == ".tif" - assert widget.as_folder == False + assert widget.as_folder is False assert Path(widget.csv_path) == ( Path(__file__).resolve().parent / "res/_train0.csv" ) diff --git a/napari_cellseg3d/_tests/test_helper.py b/napari_cellseg3d/_tests/test_helper.py index b35fc111..7c93b7d5 100644 --- a/napari_cellseg3d/_tests/test_helper.py +++ b/napari_cellseg3d/_tests/test_helper.py @@ -2,7 +2,6 @@ def test_helper(make_napari_viewer): - viewer = make_napari_viewer() widget = Helper(viewer) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 5b89c065..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -9,7 +9,6 @@ def test_inference(make_napari_viewer, qtbot): - im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_review.py b/napari_cellseg3d/_tests/test_review.py index d2b49061..fb61fb29 100644 --- a/napari_cellseg3d/_tests/test_review.py +++ b/napari_cellseg3d/_tests/test_review.py @@ -4,7 +4,6 @@ def test_launch_review(make_napari_viewer): - view = make_napari_viewer() widget = rev.Reviewer(view) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 6a7b6eeb..dc57b940 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -9,7 +9,6 @@ def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] res = [ 1, @@ -36,7 +35,6 @@ def test_fill_list_in_between(): def test_align_array_sizes(): - im = np.zeros((128, 512, 256)) print(im.shape) @@ -71,7 +69,6 @@ def test_align_array_sizes(): def test_get_padding_dim(): - tensor = torch.randn(100, 30, 40) size = tensor.size() @@ -103,14 +100,12 @@ def test_get_padding_dim(): def test_normalize_x(): - test_array = utils.normalize_x(np.array([0, 255, 127.5])) expected = np.array([-1, 1, 0]) assert np.all(test_array == expected) def test_parse_default_path(): - user_path = os.path.expanduser("~") assert utils.parse_default_path([None]) == user_path diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 88d3f887..b3121cf4 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -191,7 +191,6 @@ def display_status_report(self): if self.container_docked: self.log.clear() elif not self.container_docked: - ui.add_widgets( self.report_container.layout, [self.progress, self.log, self.btn_save_log], diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c951f176..30d147d3 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,18 +1,13 @@ -from __future__ import division -from __future__ import print_function - from dataclasses import dataclass from typing import List import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -538,14 +533,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index b5ae3920..61b9aef9 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -137,7 +137,6 @@ def show_progress(count, block_size, total_size): with tarfile.open(filename, mode="r:gz") as tar: def is_within_directory(directory, target): - abs_directory = Path(directory).resolve() abs_target = Path(target).resolve() # prefix = os.path.commonprefix([abs_directory, abs_target]) @@ -150,7 +149,6 @@ def is_within_directory(directory, target): def safe_extract( tar, path=".", members=None, *, numeric_owner=False ): - for member in tar.getmembers(): member_path = str(Path(path) / member.name) if not is_within_directory(path, member_path): @@ -274,7 +272,6 @@ def warn(self, warning): self.warn_signal.emit(warning) def log_parameters(self): - config = self.config self.log("-" * 20) @@ -301,7 +298,7 @@ def log_parameters(self): ) if config.keep_on_cpu: - self.log(f"Dataset loaded to CPU") + self.log("Dataset loaded to CPU") else: self.log(f"Dataset loaded on {config.device}") @@ -318,7 +315,6 @@ def log_parameters(self): self.log("-" * 20) def load_folder(self): - images_dict = self.create_inference_dict(self.config.images_filepaths) # TODO : better solution than loading first image always ? @@ -447,14 +443,12 @@ def model_output( post_process=True, aniso_transform=None, ): - inputs = inputs.to("cpu") - model_output = lambda inputs: post_process_transforms( - self.config.model_info.get_model().get_output( - model, inputs - ) # TODO(cyril) refactor those functions - ) + # def model_output(inputs): + # return post_process_transforms( + # self.config.model_info.get_model().get_output(model, inputs) + # ) def model_output(inputs): return post_process_transforms( @@ -513,7 +507,6 @@ def create_result_dict( # FIXME replace with result class stats=None, i=0, ): - if not from_layer and original is None: raise ValueError( "If the image is not from a layer, an original should always be available" @@ -539,7 +532,6 @@ def get_original_filename(self, i): return Path(self.config.images_filepaths[i]).stem def get_instance_result(self, semantic_labels, from_layer=False, i=-1): - if not from_layer and i == -1: raise ValueError( "An ID should be provided when running from a file" @@ -566,7 +558,6 @@ def save_image( from_layer=False, i=0, ): - if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" else: @@ -592,7 +583,6 @@ def save_image( self.log(f"\nFile n°{i+1} saved as : {filename}") def aniso_transform(self, image): - if self.config.post_process_config.zoom.enabled: zoom = self.config.post_process_config.zoom.zoom_values anisotropic_transform = Zoom( @@ -605,7 +595,6 @@ def aniso_transform(self, image): return image def instance_seg(self, to_instance, image_id=0, original_filename="layer"): - if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") @@ -631,7 +620,6 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): return instance_labels def inference_on_folder(self, inf_data, i, model, post_process_transforms): - self.log("-" * 10) self.log(f"Inference started on image n°{i + 1}...") @@ -675,7 +663,7 @@ def stats_csv(self, instance_labels): def inference_on_layer(self, image, model, post_process_transforms): self.log("-" * 10) - self.log(f"Inference started on layer...") + self.log("Inference started on layer...") image = image.type(torch.FloatTensor) @@ -939,7 +927,6 @@ def warn(self, warning): self.warn_signal.emit(warning) def log_parameters(self): - self.log("-" * 20) self.log("Parameters summary :\n") @@ -962,7 +949,7 @@ def log_parameters(self): self.log("-" * 10) if self.config.deterministic_config.enabled: - self.log(f"Deterministic training is enabled") + self.log("Deterministic training is enabled") self.log(f"Seed is {self.config.deterministic_config.seed}") self.log(f"Training for {self.config.max_epochs} epochs") @@ -1117,10 +1104,10 @@ def train(self): logger.debug(f"SAMPLING is {self.config.sampling}") if not self.config.sampling: - msg += f"Sampling is not in use, the only image provided will be used as the validation file." + msg += "Sampling is not in use, the only image provided will be used as the validation file." self.warn(msg) else: - msg += f"Samples for validation will be cropped for the same only volume that is being used for training" + msg += "Samples for validation will be cropped for the same only volume that is being used for training" logger.warning(msg) @@ -1191,7 +1178,6 @@ def train(self): ) # self.log("Loading dataset...\n") if do_sampling: - # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior if len(self.config.train_data_dict) < 2: diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 7cdf9b80..09de2a26 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -19,7 +19,6 @@ def get_output(model, input): def get_validation(model, val_inputs): - return model(val_inputs) @@ -41,7 +40,6 @@ def __init__(self, in_ch, out_ch): self.out = self.outBlock(32, out_ch, 1) def forward(self, x): - conv0 = self.conv0(x) # l0 conv1 = self.conv1(conv0) # l1 conv2 = self.conv2(conv1) # l2 @@ -67,7 +65,6 @@ def forward(self, x): return out def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), @@ -82,7 +79,6 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): return encode def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), @@ -96,7 +92,6 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): return encode def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), @@ -113,7 +108,6 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): return decode def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index d62fee26..0fc68d34 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -17,5 +17,4 @@ def get_output(model, input): def get_validation(model, val_inputs): - return model(val_inputs) diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 8e0fab3c..7c5fbaa5 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -345,7 +345,8 @@ def __init__( * A button to set a results folder - * A dropdown menu to select the file extension to be loaded from the folders""" + * A dropdown menu to select the file extension to be loaded from the folders + """ super().__init__( viewer, parent, loads_images, loads_labels, has_results ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ebcff9d7..37051fcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,12 +4,10 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QWidget from tifffile import imread from tifffile import imwrite import napari_cellseg3d.interface as ui -from napari_cellseg3d import config from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets @@ -123,7 +121,6 @@ def __init__(self, viewer: "napari.Viewer.viewer", parent=None): self._build() def _build(self): - container = ui.ContainerWidget() ui.add_widgets( @@ -147,7 +144,6 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) zoom = self.aniso_widgets.scaling_zyx() @@ -225,7 +221,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.function = clear_small_objects def _build(self): - container = ui.ContainerWidget() ui.add_widgets( @@ -311,7 +306,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() def _build(self): - container = ui.ContainerWidget() ui.add_widgets( @@ -395,7 +389,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() def _build(self): - container = ui.ContainerWidget() ui.add_widgets( @@ -458,7 +451,6 @@ class ThresholdUtils(BasePluginFolder): """ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): - super().__init__( viewer, parent, @@ -489,7 +481,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.function = threshold def _build(self): - container = ui.ContainerWidget() ui.add_widgets( diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 153b5e69..07885236 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -122,7 +122,6 @@ def _toggle_second_image_io_visibility(self): self.labels_filewidget.setVisible(crop_2nd) def _check_image_list(self): - l1 = self.image_layer_loader.layer_list l2 = self.label_layer_loader.layer_list @@ -232,7 +231,6 @@ def quicksave(self): logger.info(f"Image 2 saved as: {im2_path}") def _check_ready(self): - if self.image_layer_loader.layer_data() is not None: if self.crop_second_image: if self.label_layer_loader.layer_data() is not None: @@ -367,7 +365,6 @@ def add_isotropic_layer( return layer def _check_for_empty_layer(self, layer, volume_data): - if layer.data.all() == np.zeros_like(layer.data).all(): layer.colormap = "red" layer.data = np.random.random(layer.data.shape) @@ -378,7 +375,6 @@ def _check_for_empty_layer(self, layer, volume_data): layer.refresh() def _add_crop_layer(self, layer, cropx, cropy, cropz): - crop_data = layer.data[:cropx, :cropy, :cropz] if isinstance(layer, napari.layers.Image): diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 42c2d89e..b2356526 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -162,7 +162,7 @@ def plot_dice(self, dice_coeffs, threshold=DEFAULT_THRESHOLD): f"Session {len(self.plots)}\nMean dice : {np.mean(dice_coeffs):.4f}" ) # dice_plot.set_xticks(rotation=45) - dice_plot.set_xlabel(f"Dice coefficient") + dice_plot.set_xlabel("Dice coefficient") # dice_plot.set_ylabel("Labels pair id", rotation=90) self.canvas.draw_idle() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 0e7b05c6..483679ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,9 +10,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult @@ -591,12 +588,10 @@ def start(self): ##################### if self.folder_choice.isChecked(): - self.worker_config.images_filepaths = self.images_filepaths self.worker = InferenceWorker(worker_config=self.worker_config) elif self.layer_choice.isChecked(): - self.worker_config.layer = self.image_layer_loader.layer_data() self.worker = InferenceWorker(worker_config=self.worker_config) @@ -701,14 +696,13 @@ def on_yield(self, result: InferenceResult): self.config.show_results and image_id <= self.config.show_results_count ): - zoom = self.worker_config.post_process_config.zoom.zoom_values viewer.dims.ndisplay = 3 viewer.scale_bar.visible = True if self.config.show_original and result.original is not None: - original_layer = viewer.add_image( + viewer.add_image( result.original, colormap="inferno", name=f"original_{image_id}", @@ -720,7 +714,7 @@ def on_yield(self, result: InferenceResult): if self.worker_config.post_process_config.thresholding.enabled: out_colormap = "turbo" - out_layer = viewer.add_image( + viewer.add_image( result.result, colormap=out_colormap, name=f"pred_{image_id}_{model_name}", @@ -728,7 +722,6 @@ def on_yield(self, result: InferenceResult): ) if result.instance_labels is not None: - labels = result.instance_labels method_name = ( self.worker_config.post_process_config.instance.method.name @@ -740,12 +733,11 @@ def on_yield(self, result: InferenceResult): name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - instance_layer = viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) stats = result.stats if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() stats_df = pd.DataFrame(stats_dict) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index ac8aefc3..de54b345 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -893,7 +893,7 @@ def on_finish(self): self.log.print_and_log( f"Saving in {self.worker_config.results_path_folder}" ) - self.log.print_and_log(f"Saving last loss plot") + self.log.print_and_log("Saving last loss plot") plot_name = self.worker_config.results_path_folder / Path( f"final_metric_plots_{utils.get_time_filepath()}.png" @@ -954,7 +954,6 @@ def on_yield(self, report: TrainingReport): return if report.show_plot: - try: layer_name = "Training_checkpoint_" rge = range(len(report.images)) @@ -1022,7 +1021,6 @@ def on_yield(self, report: TrainingReport): # self.empty_cuda_cache() def _make_csv(self): - size_column = range(1, self.worker_config.max_epochs + 1) if len(self.loss_values) == 0 or self.loss_values is None: @@ -1119,7 +1117,6 @@ def update_loss_plot(self, loss, metric): elif epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): - self.canvas = FigureCanvas(Figure(figsize=(10, 1.5))) # loss plot self.train_loss_plot = self.canvas.figure.add_subplot(1, 2, 1) @@ -1162,7 +1159,6 @@ def update_loss_plot(self, loss, metric): self.plot_loss(loss, metric) else: with plt.style.context("dark_background"): - self.train_loss_plot.cla() self.dice_metric_plot.cla() diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 1ad84667..a803dfd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -189,7 +189,6 @@ def check_image_data(self): ) def _prepare_data(self): - if self.layer_choice.isChecked(): self.config.image = self.image_layer_loader.layer_data() self.config.labels = self.label_layer_loader.layer_data() @@ -217,7 +216,6 @@ def _prepare_data(self): self.config.zoom_factor = zoom def run_review(self): - """Launches review process by loading the files from the chosen folders, and adds several widgets to the napari Viewer. If the review process has been launched once before, @@ -236,7 +234,6 @@ def run_review(self): print("New review session\n" + "*" * 20) previous_viewer = self._viewer try: - self._prepare_data() self._viewer, self.docked_widgets = self.launch_review() @@ -368,7 +365,6 @@ def quicksave(): @viewer.mouse_drag_callbacks.append def update_canvas_canvas(viewer, event): - if "shift" in event.modifiers: try: cursor_position = np.round(viewer.cursor.position).astype( @@ -424,7 +420,6 @@ def update_canvas_canvas(viewer, event): datamananger._close_btn = False def update_button(axis_event): - slice_num = axis_event.value[0] logger.debug(f"slice num is {slice_num}") dmg.update_dm(slice_num) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 02a1c474..8a25d6a6 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -36,7 +36,8 @@ def __init__(self, parent: "napari.viewer.Viewer"): """Creates the datamanager widget in the specified viewer window. Args: - parent (napari.viewer.Viewer): napari Viewer for the widget to be displayed in""" + parent (napari.viewer.Viewer): napari Viewer for the widget to be displayed in + """ super().__init__() diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index ad1f5547..c962717e 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,6 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QLayout from qtpy.QtWidgets import QSizePolicy from qtpy.QtWidgets import QVBoxLayout from qtpy.QtWidgets import QWidget @@ -53,7 +52,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) def _build(self): - layout = QVBoxLayout() ui.add_widgets(layout, self.utils_widgets) layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index f0db27cc..ab3dba39 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,10 +8,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -91,7 +88,7 @@ def get_model(self): @staticmethod def get_model_name_list(): logger.info( - f"Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) + "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) ) return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 15b1b469..998645cb 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -127,7 +127,6 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index e6db6930..073d2b8b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -350,7 +350,6 @@ def __init__( class RadioButton(QRadioButton): def __init__(self, text: str = None, parent=None): - super().__init__(text, parent) @@ -451,7 +450,6 @@ def __init__( orientation=Qt.Horizontal, text_label: str = None, ): - super().__init__(orientation, parent) self.setMaximum(upper) @@ -656,7 +654,8 @@ def __init__( def _toggle_display_aniso(self): """Shows the choices for correcting anisotropy - when viewing results depending on whether :py:attr:`self.checkbox` is checked""" + when viewing results depending on whether :py:attr:`self.checkbox` is checked + """ toggle_visibility(self.checkbox, self.container) def build(self): @@ -737,24 +736,20 @@ def __init__( self._check_for_layers() def _check_for_layers(self): - for layer in self._viewer.layers: if isinstance(layer, self.layer_type): self.layer_list.addItem(layer.name) def _update_tooltip(self): - self.layer_list.setToolTip(self.layer_list.currentText()) def _add_layer(self, event): - inserted_layer = event.value if isinstance(inserted_layer, self.layer_type): self.layer_list.addItem(inserted_layer.name) def _remove_layer(self, event): - removed_layer = event.value if isinstance( @@ -762,7 +757,6 @@ def _remove_layer(self, event): ) and removed_layer.name in [ self.layer_list.itemText(i) for i in range(self.layer_list.count()) ]: - index = self.layer_list.findText(removed_layer.name) self.layer_list.removeItem(index) @@ -1044,7 +1038,8 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter""" + label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + """ super().__init__(parent) set_spinbox(self, lower, upper, default, step, fixed) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 1590e22a..090e305b 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -36,9 +36,7 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] @@ -161,7 +159,7 @@ def align_array_sizes(array_shape, target_shape): for i in range(len(targets)): targets[i] = reverse_mapping[targets[i]] infos = np.unique(origins, return_index=True, return_counts=True) - info_dict = {"origins": infos[0], "index": infos[1], "counts": infos[2]} + {"origins": infos[0], "index": infos[1], "counts": infos[2]} # print(info_dict) final_orig = [] @@ -229,7 +227,6 @@ def get_padding_dim(image_shape, anisotropy_factor=None): # problems with zero divs avoided via params for spinboxes size = int(size / anisotropy_factor[i]) while pad < size: - # if size - pad < 30: # warnings.warn( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" @@ -414,17 +411,17 @@ def parse_default_path(possible_paths): def get_date_time(): """Get date and time in the following format : year_month_day_hour_minute_second""" - return "{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%Y_%m_%d_%H_%M_%S}" def get_time(): """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid)""" - return "{:%H:%M:%S}".format(datetime.now()) + return f"{datetime.now():%H:%M:%S}" def get_time_filepath(): """Get time in the following format : hour_minute_second. Compatible with saving""" - return "{:%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%H_%M_%S}" def load_images(dir_or_path, filetype="", as_folder: bool = False): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, diff --git a/notebooks/csv_cell_plot.ipynb b/notebooks/csv_cell_plot.ipynb index 8b14fb8d..e00a9f1c 100644 --- a/notebooks/csv_cell_plot.ipynb +++ b/notebooks/csv_cell_plot.ipynb @@ -58,7 +58,6 @@ "outputs": [], "source": [ "def plot_data(data_path, x_inv=False, y_inv=False, z_inv=False):\n", - "\n", " data = pd.read_csv(data_path, index_col=False)\n", "\n", " x = data[\"Centroid x\"]\n", @@ -185,7 +184,6 @@ "outputs": [], "source": [ "def plotly_cells_stats(data):\n", - "\n", " init_notebook_mode() # initiate notebook for offline plot\n", "\n", " x = data[\"Centroid x\"]\n", From 7eae5b3dfb41ca256ad6ebfe9c3d6a232d100b78 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 021/577] Update .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 20fdcd08..8985098d 100644 --- a/.gitignore +++ b/.gitignore @@ -104,4 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -notebooks/instance_test.ipynb + From 6e39971b39fb926084f3ed71d82e8c25f68f8b6f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 022/577] Version bump --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 6e2681e8..2c537225 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc1" +__version__ = "0.0.2rc2" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 9a83e0d8..999b7fa1 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 9cff5fa8..41ee3a80 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc1 +version = 0.0.2rc2 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 0d30851e573a3ce69a4962fc5e6837baef074b86 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:33:40 +0200 Subject: [PATCH 023/577] Revert "Version bump" This reverts commit 6e39971b39fb926084f3ed71d82e8c25f68f8b6f. --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 2c537225..6e2681e8 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc2" +__version__ = "0.0.2rc1" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 999b7fa1..9a83e0d8 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 41ee3a80..9cff5fa8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc2 +version = 0.0.2rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From b6a3e49717575d69fc216b97c01ec72750e268ca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 024/577] Updated project files --- pyproject.toml | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 7 ++++-- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d596d54..5dec250c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,48 @@ +[project] +name = "napari_cellseg3d" +version = "0.0.2rc6" +authors = [ + {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, + {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, + {name = "Mackenzie Mathis", email = "mackenzie@post.harvard.edu"}, +] +requires-python = ">=3.8" +dependencies = [ + "numpy", + "napari[all]>=0.4.14", + "QtPy", + "opencv-python>=4.5.5", + "dask-image>=0.6.0", + "scikit-image>=0.19.2", + "matplotlib>=3.4.1", + "tifffile>=2022.2.9", + "imageio-ffmpeg>=0.4.5", + "torch>=1.11", + "monai[nibabel,einops]>=0.9.0", + "itk", + "tqdm", + "nibabel", + "scikit-image", + "pillow", + "pyclesperanto-prototype", + "tqdm", + "matplotlib", + "vispy>=0.9.6", +] + [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] + +[tool.setuptools.package-data] +"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] + [tool.ruff] # Never enforce `E501` (line length violations). ignore = ["E501", "E741"] @@ -12,3 +53,23 @@ line-length = 79 [tool.isort] profile = "black" line_length = 79 + +[project.optional-dependencies] +dev = [ + "isort", + "black", + "ruff", +] +docs = [ + "sphinx", + "sphinx_autodoc_typehints", + "sphinx_rtd_theme", + "twine", +] +test = [ + "pytest", + "pytest_qt", + "coverage", + "tox", + "twine", +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9cff5fa8..d8adc6ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc1 +version = 0.0.2rc6 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu @@ -65,7 +65,10 @@ install_requires = where = . [options.package_data] -napari-cellseg3d = napari.yaml +napari-cellseg3d = + res/*.png + code_models/models/pretrained/*.json + napari.yaml [options.entry_points] napari.manifest = From 4bbf74dea490cf05c125789909115e95f4ab4edf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 025/577] Fixed missing parent error --- napari_cellseg3d/code_models/model_instance_seg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 30d147d3..fb1c10cb 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -451,7 +451,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From 803bfd16498c2e010a1dd867e1fc347597f8e616 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 026/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 35 ++++++++++++------- .../code_plugins/plugin_model_inference.py | 1 + 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index fb1c10cb..add6693c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -139,6 +139,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -147,7 +150,7 @@ def voronoi_otsu( def binary_connected( - volume, + volume: np.array, thres=0.5, thres_small=3, ): @@ -161,8 +164,12 @@ def binary_connected( scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) """ + logger.debug( + f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" + ) + # if len(volume.shape) > 3: semantic = np.squeeze(volume) - foreground = semantic > thres # int(255 * thres) + foreground = np.where(semantic > thres, volume, 0) # int(255 * thres) segm = label(foreground) segm = remove_small_objects(segm, thres_small) @@ -205,6 +212,10 @@ def binary_watershed( rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ + logger.debug( + f"Running watershed segmentation with thres_objects={thres_objects}, thres_seeding={thres_seeding}," + f" thres_small={thres_small} and rem_seed_thres={rem_seed_thres}" + ) semantic = np.squeeze(volume) seed_map = semantic > thres_seeding foreground = semantic > thres_objects @@ -410,8 +421,8 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -444,7 +455,7 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -504,7 +515,7 @@ def __init__(self, parent=None): """ super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -523,7 +534,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -533,8 +544,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -558,9 +571,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 483679ef..fb6fb71c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -184,6 +184,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## From 928dd48b7675e62e0fd4d2f90643c1cf827d118d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 027/577] Removing dask-image --- .gitignore | 1 + napari_cellseg3d/dev_scripts/convert.py | 5 +- napari_cellseg3d/dev_scripts/view_brain.py | 2 +- napari_cellseg3d/dev_scripts/view_sample.py | 2 +- napari_cellseg3d/utils.py | 113 ++++++++++---------- notebooks/full_plot.ipynb | 3 +- setup.cfg | 1 + 7 files changed, 64 insertions(+), 63 deletions(-) diff --git a/.gitignore b/.gitignore index 8985098d..19181d5f 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png +*.prof diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 129c16be..479a07dd 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,7 +2,7 @@ import os import numpy as np -from dask_image.imread import imread +from tifffile import imread from tifffile import imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" @@ -19,8 +19,7 @@ filenames.append(os.path.basename(filename)) # print(os.path.basename(filename)) for file in paths: - img = imread(file) - image = img.compute() + image = imread(file) image[image >= 1] = 1 image = image.astype(np.uint16) diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py index e5879638..145d4e45 100644 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ b/napari_cellseg3d/dev_scripts/view_brain.py @@ -1,5 +1,5 @@ import napari -from dask_image.imread import imread +from tifffile import imread y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py index 329944ac..8e87f85c 100644 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ b/napari_cellseg3d/dev_scripts/view_sample.py @@ -1,5 +1,5 @@ import napari -from dask_image.imread import imread +from tifffile import imread # Visual x = imread( diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 090e305b..95ddd319 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -4,9 +4,8 @@ from pathlib import Path import numpy as np -from dask_image.imread import imread as dask_imread -from pandas import DataFrame -from pandas import Series + +# from dask import delayed from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -17,7 +16,6 @@ LOGGER.setLevel(logging.DEBUG) # LOGGER.setLevel(logging.INFO) ############### - """ utils.py ==================================== @@ -278,51 +276,51 @@ def annotation_to_input(label_ermito): return anno -def check_csv(project_path, ext): - if not Path(Path(project_path) / Path(project_path).name).is_file(): - cols = [ - "project", - "type", - "ext", - "z", - "y", - "x", - "z_size", - "y_size", - "x_size", - "created_date", - "update_date", - "path", - "notes", - ] - df = DataFrame(index=[], columns=cols) - filename_pattern_original = Path(project_path) / Path( - f"dataset/Original_size/Original/*{ext}" - ) - images_original = dask_imread(filename_pattern_original) - z, y, x = images_original.shape - record = Series( - [ - Path(project_path).name, - "dataset", - ".tif", - 0, - 0, - 0, - z, - y, - x, - datetime.datetime.now(), - "", - Path(project_path) / Path("dataset/Original_size/Original"), - "", - ], - index=df.columns, - ) - df = df.append(record, ignore_index=True) - df.to_csv(Path(project_path) / Path(project_path).name) - else: - pass +# def check_csv(project_path, ext): +# if not Path(Path(project_path) / Path(project_path).name).is_file(): +# cols = [ +# "project", +# "type", +# "ext", +# "z", +# "y", +# "x", +# "z_size", +# "y_size", +# "x_size", +# "created_date", +# "update_date", +# "path", +# "notes", +# ] +# df = DataFrame(index=[], columns=cols) +# filename_pattern_original = Path(project_path) / Path( +# f"dataset/Original_size/Original/*{ext}" +# ) +# images_original = dask_imread(filename_pattern_original) +# z, y, x = images_original.shape +# record = Series( +# [ +# Path(project_path).name, +# "dataset", +# ".tif", +# 0, +# 0, +# 0, +# z, +# y, +# x, +# datetime.datetime.now(), +# "", +# Path(project_path) / Path("dataset/Original_size/Original"), +# "", +# ], +# index=df.columns, +# ) +# df = df.append(record, ignore_index=True) +# df.to_csv(Path(project_path) / Path(project_path).name) +# else: +# pass # def check_annotations_dir(project_path): @@ -457,7 +455,10 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): raise ValueError("If loading as a folder, filetype must be specified") if as_folder: - images_original = dask_imread(filename_pattern_original) + raise NotImplementedError( + "Loading as folder not implemented yet. Use napari to load as folder" + ) + # images_original = dask_imread(filename_pattern_original) else: images_original = tfl_imread( filename_pattern_original @@ -478,12 +479,12 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): # return base_label -def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): - images_label = load_images(mod_mask_dir, filetype, as_folder) - if as_folder: - images_label = images_label.compute() - base_label = images_label - return base_label +# def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): +# images_label = load_images(mod_mask_dir, filetype, as_folder) +# if as_folder: +# images_label = images_label.compute() +# base_label = images_label +# return base_label def save_stack(images, out_path, filetype=".png", check_warnings=False): diff --git a/notebooks/full_plot.ipynb b/notebooks/full_plot.ipynb index 857384d4..f804598e 100644 --- a/notebooks/full_plot.ipynb +++ b/notebooks/full_plot.ipynb @@ -10,8 +10,7 @@ "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", - "from PIL import Image\n", - "from dask_image.imread import imread" + "from tifffile import imread" ] }, { diff --git a/setup.cfg b/setup.cfg index d8adc6ae..2420dd1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ package_dir = # add your package requirements here # the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +# FIXME remove dask install_requires = numpy napari[all]>=0.4.14 From 0ac83625ee23a06a69ba7c5b813b5ba59741cb72 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 028/577] Fixed erroneous dtype conversion --- napari_cellseg3d/code_models/model_instance_seg.py | 13 +++++++++++-- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index add6693c..412c87d7 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -138,12 +138,12 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -492,6 +492,15 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 37051fcc..db755317 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -151,7 +151,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) save_layer( @@ -169,7 +169,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - utils.resize(np.array(imread(file), dtype=np.int16), zoom) + utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] save_folder( @@ -250,7 +250,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( @@ -331,7 +331,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) semantic = to_semantic(data) save_layer( @@ -416,7 +416,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) instance = self.instance_widgets.run_method(data) save_layer( @@ -511,7 +511,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( From e194307184804f002c99c71e0079be97a86f53a3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 029/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index f0fac98a..24f4e867 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,3 +1,7 @@ +from pathlib import Path +from tifffile import imread +import numpy as np + from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS @@ -6,9 +10,15 @@ def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(im_path) + view.add_image(image) + view.add_labels(image.astype(np.uint8)) + view.window.add_dock_widget(widget) for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + widget.utils_widgets[i]._start() From e4be4d49179b183866a556ad10add4f107cf28a7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:38:13 +0200 Subject: [PATCH 030/577] Temporary test action patch --- .github/workflows/test_and_deploy.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 11b8776f..ef5c4e0e 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,12 +8,14 @@ on: branches: - main - npe2 + - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 + - cy/voronoi-otsu workflow_dispatch: jobs: From 1193190d7bd767424d25c37fdc29e6dbea3d22e4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:50:16 +0200 Subject: [PATCH 031/577] Update plugin_convert.py --- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index db755317..ed1a43df 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -35,7 +35,7 @@ def save_folder(results_path, folder_name, images, image_paths): image_paths: list of filenames of images """ results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False) + results_folder.mkdir(exist_ok=False, parents=True) for file, image in zip(image_paths, images): path = results_folder / Path(file).name @@ -144,7 +144,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -243,7 +243,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -325,7 +325,7 @@ def _build(self): ) def _start(self): - Path(self.results_path).mkdir(exist_ok=True) + Path(self.results_path).mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -410,7 +410,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -504,7 +504,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.binarize_counter.value() if self.layer_choice: From baaaae00a272bae9af12bfd086dcecd100681cc7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 032/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 139b04f5..7e48111d 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ deps = magicgui pytest-qt qtpy + pocl ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From bbb28230821b76be77f2ca274477746af59dfee2 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 033/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 7e48111d..4abf987c 100644 --- a/tox.ini +++ b/tox.ini @@ -35,7 +35,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 1307cae82e33570fbc431a8b32211834d569963d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 034/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4abf987c..2d8953f6 100644 --- a/tox.ini +++ b/tox.ini @@ -35,7 +35,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 48a097e685b0b9b295169583d04f245fbf5bd067 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 035/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 24f4e867..df42c1d6 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -21,4 +21,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index 2d8953f6..f70b2357 100644 --- a/tox.ini +++ b/tox.ini @@ -35,7 +35,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 16940ea18879d7c910c42b83ca9d59755d2505a8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 036/577] Relabeling tests --- .gitignore | 6 +- napari_cellseg3d/_tests/res/test_labels.tif | Bin 0 -> 2026 bytes .../_tests/test_labels_correction.py | 51 ++++++++++ .../dev_scripts/artefact_labeling.py | 93 ++++++++---------- .../dev_scripts/correct_labels.py | 75 +++++++++----- 5 files changed, 152 insertions(+), 73 deletions(-) create mode 100644 napari_cellseg3d/_tests/res/test_labels.tif create mode 100644 napari_cellseg3d/_tests/test_labels_correction.py diff --git a/.gitignore b/.gitignore index 19181d5f..15fa210a 100644 --- a/.gitignore +++ b/.gitignore @@ -104,5 +104,9 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png - *.prof + +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/res/test_labels.tif b/napari_cellseg3d/_tests/res/test_labels.tif new file mode 100644 index 0000000000000000000000000000000000000000..0486d789ea658acc32616b40869833accf8d01d7 GIT binary patch literal 2026 zcmcK5yJ}QX6b9gPW+oRk7ZaUC6EDMfA4AYa#WzSNSc-*3f&q(wHX^pxK7x-TK7$V+ zh-hgcUQztNum?}HQam9)Yn^rd*V>ys8yll)x~i)As;YZc9c?nG8+xbi?%D^jcZTs? z;A_lNkw1zQI~mBsOB}G_#)iYWBJ~`{jpd=(^f$j8o7U4Q+J#zTzQ_J1_z;*uteC68 z$zUP)5+6=ue*NfXR{vChyE>l&{pDQ@Rs-|ldNz=U59v(k#{#wVv17fKBh6$ldYFA! zj2TD_w8=&Bcb7Z$2DI=OnVnep55ES3J|ZCRZ7^|q`;Z@w+f_hcAf uO7Fq{VSFjiRU3?7w#N8*ON^i72d14J-{`ip<7-oGF@Dt&<6PlCcKj3h Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 037/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 43 ++++++++++++----------------------------- pyproject.toml | 3 ++- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1e22fb1..da16a3b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,44 +1,25 @@ repos: -# - repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v4.0.1 -# hooks: -# - id: check-docstring-first -# - id: end-of-file-fixer -# - id: trailing-whitespace -# - repo: https://github.com/asottile/setup-cfg-fmt -# rev: v1.20.0 -# hooks: -# - id: setup-cfg-fmt -# - repo: https://github.com/PyCQA/flake8 -# rev: 4.0.1 -# hooks: -# - id: flake8 -# additional_dependencies: [flake8-typing-imports>=1.9.0] -# - repo: https://github.com/myint/autoflake -# rev: v1.4 -# hooks: -# - id: autoflake -# args: ["--in-place", "--remove-all-unused-imports"] -# - repo: https://github.com/PyCQA/isort -# rev: 5.10.1 -# hooks: -# - id: isort + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.257' + rev: 'v0.0.262' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black args: [--line-length=79] -# - repo: https://github.com/asottile/pyupgrade -# rev: v2.29.1 -# hooks: -# - id: pyupgrade -# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 5dec250c..d2a2adbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "isort", "black", "ruff", + "pre-commit", ] docs = [ "sphinx", @@ -72,4 +73,4 @@ test = [ "coverage", "tox", "twine", -] \ No newline at end of file +] From f2bdf5513e49c66bf97b594a03e3995211086300 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 038/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da16a3b9..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,13 +2,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-docstring-first +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From ad64f2e9d635dc27f1347585bbea29bcfcf02dc8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 039/577] Run full suite of pre-commit hooks --- .github/workflows/preview_metadata.yml | 2 +- .isort.cfg | 2 +- .napari/DESCRIPTION.md | 22 +++++++++---------- README.md | 4 ++-- docs/index.rst | 1 - docs/res/code/interface.rst | 19 ---------------- docs/res/code/model_framework.rst | 2 +- docs/res/code/model_workers.rst | 5 ----- docs/res/code/plugin_base.rst | 2 +- docs/res/code/plugin_convert.rst | 2 +- docs/res/code/plugin_crop.rst | 4 ---- docs/res/code/plugin_metrics.rst | 2 +- docs/res/code/plugin_model_training.rst | 2 +- docs/res/code/plugin_review.rst | 3 --- docs/res/code/plugin_review_dock.rst | 2 -- docs/res/code/utils.rst | 21 ------------------ docs/res/guides/metrics_module_guide.rst | 4 ---- docs/res/guides/utils_module_guide.rst | 6 ----- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 3 +-- .../pretrained/pretrained_model_urls.json | 2 +- .../code_models/models/unet/model.py | 6 ++--- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- napari_cellseg3d/napari.yaml | 1 - tox.ini | 6 ++--- 30 files changed, 55 insertions(+), 114 deletions(-) diff --git a/.github/workflows/preview_metadata.yml b/.github/workflows/preview_metadata.yml index 98651d4b..77c67326 100644 --- a/.github/workflows/preview_metadata.yml +++ b/.github/workflows/preview_metadata.yml @@ -18,4 +18,4 @@ jobs: - name: napari hub Preview Page Builder uses: chanzuckerberg/napari-hub-preview-action@v0.1.6 with: - hub-ref: main \ No newline at end of file + hub-ref: main diff --git a/.isort.cfg b/.isort.cfg index 9b5f551d..2b497c7c 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,6 +1,6 @@ [settings] force_single_line = True -force_sort_within_sections = False +force_sort_within_sections = False lexicographical = True single_line_exclusions = ('typing',) order_by_type = False diff --git a/.napari/DESCRIPTION.md b/.napari/DESCRIPTION.md index 5d0c4267..9a3143bb 100644 --- a/.napari/DESCRIPTION.md +++ b/.napari/DESCRIPTION.md @@ -7,8 +7,8 @@ the functionality of your plugin. Its content will be rendered on your plugin's napari hub page. The sections below are given as a guide for the flow of information only, and -are in no way prescriptive. You should feel free to merge, remove, add and -rename sections at will to make this document work best for your plugin. +are in no way prescriptive. You should feel free to merge, remove, add and +rename sections at will to make this document work best for your plugin. --> ## Description @@ -17,13 +17,13 @@ A napari plugin for 3D cell segmentation: training, inference, and data review. A detailed walk-through and description is available [on the documentation website](https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html). -## Additional Install Steps +## Additional Install Steps **Python >= 3.8 required** @@ -110,8 +110,8 @@ for the majority of plugins. They will include instructions to pip install, and to install via napari itself. Most plugins can be installed out-of-the-box by just specifying the package requirements -over in `setup.cfg`. However, if your plugin has any more complex dependencies, or -requires any additional preparation before (or after) installation, you should add +over in `setup.cfg`. However, if your plugin has any more complex dependencies, or +requires any additional preparation before (or after) installation, you should add this information here. --> ## Getting Help @@ -131,7 +131,7 @@ here. \ No newline at end of file +--> diff --git a/README.md b/README.md index c14e1a2c..05b49dcb 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Note : we recommend using conda to create a new environment for the plugin. conda create --name python=3.8 napari-cellseg3d conda activate napari-cellseg3d -You can install `napari-cellseg3d` via [pip]: +You can install `napari-cellseg3d` via [pip]: pip install napari-cellseg3d @@ -123,7 +123,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/docs/index.rst b/docs/index.rst index 90f430a0..7e809fbe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -54,4 +54,3 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` * :ref:`search` - diff --git a/docs/res/code/interface.rst b/docs/res/code/interface.rst index 3bc4f914..8bf43e04 100644 --- a/docs/res/code/interface.rst +++ b/docs/res/code/interface.rst @@ -105,22 +105,3 @@ toggle_visibility open_file_dialog ************************************** .. autofunction:: napari_cellseg3d.interface::open_file_dialog - - - - - - - - - - - - - - - - - - - diff --git a/docs/res/code/model_framework.rst b/docs/res/code/model_framework.rst index 59f6d004..a3483f5a 100644 --- a/docs/res/code/model_framework.rst +++ b/docs/res/code/model_framework.rst @@ -20,4 +20,4 @@ Attributes ********************* .. autoclass:: napari_cellseg3d.code_models.model_framework::ModelFramework - :members: _viewer, worker, docked_widgets, images_filepaths, labels_filepaths, results_path \ No newline at end of file + :members: _viewer, worker, docked_widgets, images_filepaths, labels_filepaths, results_path diff --git a/docs/res/code/model_workers.rst b/docs/res/code/model_workers.rst index 914f6507..85f8da29 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/model_workers.rst @@ -42,8 +42,3 @@ Methods .. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker :members: __init__, log, train :noindex: - - - - - diff --git a/docs/res/code/plugin_base.rst b/docs/res/code/plugin_base.rst index af4a4f4e..f1015c7e 100644 --- a/docs/res/code/plugin_base.rst +++ b/docs/res/code/plugin_base.rst @@ -35,4 +35,4 @@ Attributes ********************* .. autoclass:: napari_cellseg3d.code_plugins.plugin_base::BasePluginFolder - :members: _viewer, images_filepaths, labels_filepaths, results_path \ No newline at end of file + :members: _viewer, images_filepaths, labels_filepaths, results_path diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 43b8c7be..03944510 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -42,4 +42,4 @@ save_layer show_result **************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result \ No newline at end of file +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/plugin_crop.rst b/docs/res/code/plugin_crop.rst index cb311d74..f52fd025 100644 --- a/docs/res/code/plugin_crop.rst +++ b/docs/res/code/plugin_crop.rst @@ -23,7 +23,3 @@ Attributes .. autoclass:: napari_cellseg3d.code_plugins.plugin_crop::Cropping :members: _viewer, image_path, label_path - - - - diff --git a/docs/res/code/plugin_metrics.rst b/docs/res/code/plugin_metrics.rst index f9014edb..29c2fe25 100644 --- a/docs/res/code/plugin_metrics.rst +++ b/docs/res/code/plugin_metrics.rst @@ -19,4 +19,4 @@ Attributes ********************* .. autoclass:: napari_cellseg3d.code_plugins.plugin_metrics::MetricsUtils - :members: _viewer, layout, canvas, plots \ No newline at end of file + :members: _viewer, layout, canvas, plots diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index a531b877..870dfd14 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -20,4 +20,4 @@ Attributes ********************* .. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer - :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot \ No newline at end of file + :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot diff --git a/docs/res/code/plugin_review.rst b/docs/res/code/plugin_review.rst index f61e5661..69397400 100644 --- a/docs/res/code/plugin_review.rst +++ b/docs/res/code/plugin_review.rst @@ -24,6 +24,3 @@ Attributes .. autoclass:: napari_cellseg3d.code_plugins.plugin_review::Reviewer :members: _viewer, image_path, label_path - - - diff --git a/docs/res/code/plugin_review_dock.rst b/docs/res/code/plugin_review_dock.rst index 597b30e9..3aa0f6ae 100644 --- a/docs/res/code/plugin_review_dock.rst +++ b/docs/res/code/plugin_review_dock.rst @@ -19,5 +19,3 @@ Attributes .. autoclass:: napari_cellseg3d.code_plugins.plugin_review_dock::Datamanager :members: viewer - - diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index 15d7fc1d..e90ee7e0 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -66,24 +66,3 @@ load_images format_Warning ************************************** .. autofunction:: napari_cellseg3d.utils::format_Warning - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/res/guides/metrics_module_guide.rst b/docs/res/guides/metrics_module_guide.rst index 4a7b0c60..98899ad9 100644 --- a/docs/res/guides/metrics_module_guide.rst +++ b/docs/res/guides/metrics_module_guide.rst @@ -47,7 +47,3 @@ Source code * :doc:`../code/plugin_base` * :doc:`../code/plugin_metrics` - - - - diff --git a/docs/res/guides/utils_module_guide.rst b/docs/res/guides/utils_module_guide.rst index fd9f7401..407ae710 100644 --- a/docs/res/guides/utils_module_guide.rst +++ b/docs/res/guides/utils_module_guide.rst @@ -36,9 +36,3 @@ Source code * :doc:`../code/plugin_base` * :doc:`../code/plugin_convert` - - - - - - diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index df42c1d6..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 412c87d7..c72bafe9 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,8 +15,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -492,7 +492,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index 9331484b..cd0782fb 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -4,4 +4,4 @@ "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" -} \ No newline at end of file +} diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index a31cc580..6cc76be6 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -57,7 +57,7 @@ def __init__( conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, - **kwargs + **kwargs, ): super(Abstract3DUNet, self).__init__() @@ -153,7 +153,7 @@ def __init__( num_levels=4, is_segmentation=True, conv_padding=1, - **kwargs + **kwargs, ): super(UNet3D, self).__init__( in_channels=in_channels, @@ -166,5 +166,5 @@ def __init__( num_levels=num_levels, is_segmentation=is_segmentation, conv_padding=conv_padding, - **kwargs + **kwargs, ) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log diff --git a/napari_cellseg3d/napari.yaml b/napari_cellseg3d/napari.yaml index 82058b9e..8513a957 100644 --- a/napari_cellseg3d/napari.yaml +++ b/napari_cellseg3d/napari.yaml @@ -40,4 +40,3 @@ contributions: - command: napari-cellseg3d.help display_name: Help/About... - diff --git a/tox.ini b/tox.ini index f70b2357..784a056d 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ isolated_build=true [gh-actions] python = 3.8: py38 - + [gh-actions:env] PLATFORM = ubuntu-latest: linux @@ -14,11 +14,11 @@ PLATFORM = ; windows-latest: windows [testenv] -platform = +platform = linux: linux ; macos: darwin ; windows: win32 -passenv = +passenv = CI PYTHONPATH GITHUB_ACTIONS From a84fefdcb41b8a5c2c0927940b920c2856794eb2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 040/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- docs/res/code/plugin_convert.rst | 5 - .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 276 +++++++++++++++++- napari_cellseg3d/code_models/model_workers.py | 32 +- .../code_plugins/plugin_convert.py | 155 +--------- .../code_plugins/plugin_model_inference.py | 25 +- napari_cellseg3d/config.py | 38 +-- napari_cellseg3d/interface.py | 21 +- requirements.txt | 1 + 9 files changed, 306 insertions(+), 248 deletions(-) diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index c7dc7df9..03944510 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -19,11 +19,6 @@ ToSemanticUtils .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToSemanticUtils :members: __init__ -InstanceWidgets -********************************** -.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::InstanceWidgets - :members: __init__, run_method - ToInstanceUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToInstanceUtils diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..584ffd3b 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,6 +8,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e4bec4ea..2cb7728b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -6,20 +6,65 @@ import numpy as np -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area +import pyclesperanto_prototype as cle +from qtpy.QtWidgets import QWidget + from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed + +from skimage.filters import thresholding +from skimage.transform import resize + +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread +from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -50,18 +95,43 @@ def get_dict(self): def threshold(volume, thresh): + """Remove all values smaller than the specified threshold in the volume""" im = np.squeeze(volume) binary = im > thresh return np.where(binary, im, np.zeros_like(im)) +def voronoi_otsu( + volume: np.ndarray, + spot_sigma: float, + outline_sigma: float, + remove_small_size: float, +): + """ + Voronoi-Otsu labeling from pyclesperanto. + BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase + https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: + volume (np.ndarray): volume to segment + spot_sigma (float): parameter determining how close detected objects can be + outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels + + Returns: + Instance segmentation labels from Voronoi-Otsu method + """ + semantic = np.squeeze(volume) + instance = cle.voronoi_otsu_labeling( + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + ) + # instance = remove_small_objects(instance, remove_small_size) + return instance + + def binary_connected( volume, thres=0.5, thres_small=3, - # scale_factors=(1.0, 1.0, 1.0), - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via connected-component labeling. @@ -70,7 +140,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) """ semantic = np.squeeze(volume) foreground = semantic > thres # int(255 * thres) @@ -97,12 +166,9 @@ def binary_connected( def binary_watershed( volume, thres_objects=0.3, - thres_small=10, thres_seeding=0.9, - # scale_factors=(1.0, 1.0, 1.0), + thres_small=10, rem_seed_thres=3, - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via watershed segmentation algorithm. @@ -113,10 +179,9 @@ def binary_watershed( Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. - thres_seeding (float): threshold for seeding. Default: 0.98 thres_objects (float): threshold for foreground objects. Default: 0.3 + thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ semantic = np.squeeze(volume) @@ -193,7 +258,7 @@ def to_instance(image, is_file_path=False): result = binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # TODO add params + ) # FIXME add params from utils plugin return result @@ -283,3 +348,188 @@ def fill(lst, n=len(properties) - 1): ratio, fill([len(properties)]), ) + + +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Watershed", + function=binary_watershed, + num_sliders=2, + num_counters=2, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(50) + + self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].tooltips = "Probability threshold for seeding" + self.sliders[1].setValue(90) + + self.counters[0].label.setText("Small object removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(30) + + self.counters[1].label.setText("Small seed removal") + self.counters[1].tooltips = ( + "Volume/size threshold for small seeds removal." + "\nAll seeds with a volume/size below this value will be removed." + ) + self.counters[1].setValue(3) + + def run_method(self, image): + return self.function( + image, + self.sliders[0].value(), + self.sliders[1].value(), + self.counters[0].value(), + self.counters[1].value(), + ) + + +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Connected Components", + function=binary_connected, + num_sliders=1, + num_counters=1, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(80) + + self.counters[0].label.setText("Small objects removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(3) + + def run_method(self, image): + return self.function( + image, self.sliders[0].value(), self.counters[0].value() + ) + + +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Voronoi-Otsu", + function=voronoi_otsu, + num_sliders=0, + num_counters=3, + ) + self.counters[0].label.setText("Spot sigma") + self.counters[ + 0 + ].tooltips = "Determines how close detected objects can be" + self.counters[0].setMaximum(100) + self.counters[0].setValue(2) + + self.counters[1].label.setText("Outline sigma") + self.counters[ + 1 + ].tooltips = "Determines the smoothness of the segmentation" + self.counters[1].setMaximum(100) + self.counters[1].setValue(2) + + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + + def run_method(self, image): + return self.function( + image, + self.counters[0].value(), + self.counters[1].value(), + self.counters[2].value(), + ) + + +class InstanceWidgets(QWidget): + """ + Base widget with several sliders, for use in instance segmentation parameters + """ + + def __init__(self, parent=None): + """ + Creates an InstanceWidgets widget + + Args: + parent: parent widget + """ + super().__init__(parent) + + self.method_choice = ui.DropdownMenu( + INSTANCE_SEGMENTATION_METHOD_LIST.keys() + ) + self.methods = [] + self.instance_widgets = {} + + self.method_choice.currentTextChanged.connect(self._set_visibility) + self._build() + + def _build(self): + + group = ui.GroupedWidget("Instance segmentation") + group.layout.addWidget(self.method_choice) + + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + + self.setLayout(group.layout) + self._set_visibility() + + def _set_visibility(self): + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: + widget.set_visibility(False) + + def run_method(self, volume): + """ + Calls instance function with chosen parameters + Args: + volume: image data to run method on + + Returns: processed image from self._method + """ + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + return method.run_method(volume) + + +INSTANCE_SEGMENTATION_METHOD_LIST = { + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, +} diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 8b1caf8e..a43077e4 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -50,12 +50,6 @@ from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, -) from napari_cellseg3d.code_models.model_instance_seg import ImageStats from napari_cellseg3d.code_models.model_instance_seg import volume_stats @@ -603,30 +597,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - threshold = ( - self.config.post_process_config.instance.threshold.threshold_value - ) - size_small = ( - self.config.post_process_config.instance.small_object_removal_threshold.threshold_value - ) - method_name = self.config.post_process_config.instance.method - - if method_name == "Watershed": # FIXME use dict in config instead - - def method(image): - return binary_watershed(image, threshold, size_small) - - elif method_name == "Connected components": - - def method(image): - return binary_connected(image, threshold, size_small) - - else: - raise NotImplementedError( - "Selected instance segmentation method is not defined" - ) - - instance_labels = method(to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 5560b7b9..f461b46f 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -16,6 +16,7 @@ ) from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -358,160 +359,6 @@ def _start(self): self.images_filepaths, ) - -class InstanceWidgets(QWidget): - """ - Base widget with several sliders, for use in instance segmentation parameters - """ - - def __init__(self, parent=None): - """ - Creates an InstanceWidgets widget - - Args: - parent: parent widget - """ - super().__init__(parent) - - self.method_choice = ui.DropdownMenu( - config.INSTANCE_SEGMENTATION_METHOD_LIST.keys() - ) - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ] - - self.method_choice.currentTextChanged.connect(self._show_connected) - self.method_choice.currentTextChanged.connect(self._show_watershed) - - self.threshold_slider1 = ui.Slider( - lower=0, - upper=100, - default=50, - divide_factor=100.0, - step=5, - text_label="Probability threshold :", - ) - """Base prob. threshold""" - self.threshold_slider2 = ui.Slider( - lower=0, - upper=100, - default=90, - divide_factor=100.0, - step=5, - text_label="Probability threshold (seeding) :", - ) - """Second prob. thresh. (seeding)""" - - self.counter1 = ui.IntIncrementCounter( - upper=100, - default=10, - step=5, - label="Small object removal (pxs) :", - ) - """Small obj. rem.""" - - self.counter2 = ui.IntIncrementCounter( - upper=100, - default=3, - step=5, - label="Small seed removal (pxs) :", - ) - """Small seed rem.""" - - self._build() - - def run_method(self, volume): - """ - Calls instance function with chosen parameters - Args: - volume: image data to run method on - - Returns: processed image from self._method - """ - return self._method( - volume, - self.threshold_slider1.slider_value, - self.counter1.value(), - self.threshold_slider2.slider_value, - self.counter2.value(), - ) - - def _build(self): - group = ui.GroupedWidget("Instance segmentation") - - ui.add_widgets( - group.layout, - [ - self.method_choice, - self.threshold_slider1.container, - self.threshold_slider2.container, - self.counter1.label, - self.counter1, - self.counter2.label, - self.counter2, - ], - ) - - self.setLayout(group.layout) - self._set_tooltips() - - def _set_tooltips(self): - self.method_choice.setToolTip( - "Choose which method to use for instance segmentation" - "\nConnected components : all separated objects will be assigned an unique ID. " - "Robust but will not work correctly with adjacent/touching objects\n" - "Watershed : assigns objects ID based on the probability gradient surrounding an object. " - "Requires the model to surround objects in a gradient;" - " can possibly correctly separate unique but touching/adjacent objects." - ) - self.threshold_slider1.tooltips = ( - "All objects below this probability will be ignored (set to 0)" - ) - self.counter1.setToolTip( - "Will remove all objects smaller (in volume) than the specified number of pixels" - ) - self.threshold_slider2.tooltips = ( - "All seeds below this probability will be ignored (set to 0)" - ) - self.counter2.setToolTip( - "Will remove all seeds smaller (in volume) than the specified number of pixels" - ) - - def _show_watershed(self): - name = "Watershed" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2() - self._show_counter1() - self._show_counter2() - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_connected(self): - name = "Connected components" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2(False) - self._show_counter1() - self._show_counter2(False) - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_slider1(self, is_visible: bool = True): - self.threshold_slider1.container.setVisible(is_visible) - - def _show_slider2(self, is_visible: bool = True): - self.threshold_slider2.container.setVisible(is_visible) - - def _show_counter1(self, is_visible: bool = True): - self.counter1.setVisible(is_visible) - self.counter1.label.setVisible(is_visible) - - def _show_counter2(self, is_visible: bool = True): - self.counter2.setVisible(is_visible) - self.counter2.label.setVisible(is_visible) - - class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 0dca3ec8..4a7ab671 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -12,7 +12,11 @@ from napari_cellseg3d.code_models.model_framework import ModelFramework from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_plugins.plugin_convert import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -77,9 +81,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): config.InferenceWorkerConfig() ) """InferenceWorkerConfig class from config.py""" - self.instance_config: config.InstanceSegConfig = ( - config.InstanceSegConfig() - ) + self.instance_config: InstanceMethod """InstanceSegConfig class from config.py""" self.post_process_config: config.PostProcessConfig = ( config.PostProcessConfig() @@ -551,18 +553,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - instance_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.threshold_slider1.slider_value - ) - instance_small_object_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.counter1.value() - ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.method_choice.currentText(), - threshold=instance_thresh_config, - small_object_removal_threshold=instance_small_object_thresh_config, - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 57c65bac..74cbf81d 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,22 +8,20 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, -) -from napari_cellseg3d.code_models.models import ( - model_SegResNet as SegResNet, -) -from napari_cellseg3d.code_models.models import ( - model_SwinUNetR as SwinUNetR, -) + +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet +from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR from napari_cellseg3d.code_models.models import ( model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -40,10 +38,6 @@ # "test" : DO NOT USE, reserved for testing } -INSTANCE_SEGMENTATION_METHOD_LIST = { - "Watershed": binary_watershed, - "Connected components": binary_connected, -} WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") @@ -127,21 +121,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: str = None - threshold: Thresholding = Thresholding(enabled=False, threshold_value=0.85) - small_object_removal_threshold: Thresholding = Thresholding( - enabled=True, threshold_value=20 - ) - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d23199ee..d2f8d787 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -12,6 +12,7 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt +# from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices @@ -499,9 +500,12 @@ def __init__( self._build_container() - def _build_container(self): - self.container.layout + def set_visibility(self, visible: bool): + self.container.setVisible(visible) + self.setVisible(visible) + self.text_label.setVisible(visible) + def _build_container(self): if self.text_label is not None: add_widgets( self.container.layout, @@ -1021,7 +1025,7 @@ class DoubleIncrementCounter(QDoubleSpinBox): def __init__( self, lower: Optional[float] = 0.0, - upper: Optional[float] = 10.0, + upper: Optional[float] = 1000.0, default: Optional[float] = 0.0, step: Optional[float] = 1.0, parent: Optional[QWidget] = None, @@ -1045,6 +1049,13 @@ def __init__( if label is not None: self.label = make_label(name=label) + self.valueChanged.connect(self._update_step) + + def _update_step(self): + if self.value() < 0.9: + self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): @@ -1081,6 +1092,10 @@ def make_n( cls, n, lower, upper, default, step, parent, fixed ) + def set_visibility(self, visible: bool): + self.setVisible(visible) + self.label.setVisible(visible) + class IntIncrementCounter(QSpinBox): """Class implementing a number counter with increments (spin box) for int.""" diff --git a/requirements.txt b/requirements.txt index 3ba73405..739b7aa3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From e5b40f55690fd40f28c342a18d3b21c99eb43f6b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 041/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2cb7728b..81b1744b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -105,7 +105,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -115,11 +115,12 @@ def voronoi_otsu( volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels + Returns: Instance segmentation labels from Voronoi-Otsu method """ + # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma @@ -427,7 +428,7 @@ def __init__(self): name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, ) self.counters[0].label.setText("Spot sigma") self.counters[ @@ -443,18 +444,19 @@ def __init__(self): self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): return self.function( image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) From d40fd9567013a9ddce4c929aacad69961e908fca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 042/577] Added new docs for instance seg --- docs/res/code/model_instance_seg.rst | 23 +++++++++++++++++++ .../code_models/model_instance_seg.py | 22 ++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst index e4146ec1..3b323173 100644 --- a/docs/res/code/model_instance_seg.rst +++ b/docs/res/code/model_instance_seg.rst @@ -1,6 +1,29 @@ model_instance_seg.py =========================================== +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu + :members: __init__ + Functions ------------- diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 81b1744b..7fd33317 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -37,6 +37,14 @@ def __init__( num_sliders: int, num_counters: int, ): + """ + Methods for instance segmentation + Args: + name: Name of the instance segmentation method (for UI) + function: Function to use for instance segmentation + num_sliders: Number of Slider UI elements needed to set the parameters of the function + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + """ self.name = name self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] @@ -176,7 +184,7 @@ def binary_watershed( Note: This function uses the `skimage.segmentation.watershed `_ - function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays. + function that converts the input image into ``np.float64`` data type for processing. Therefore, please make sure enough memory is allocated when handling large arrays. Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. @@ -352,6 +360,8 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod, metaclass=Singleton): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + def __init__(self): super().__init__( name="Watershed", @@ -395,6 +405,8 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod, metaclass=Singleton): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + def __init__(self): super().__init__( name="Connected Components", @@ -423,6 +435,8 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + def __init__(self): super().__init__( name="Voronoi-Otsu", @@ -430,14 +444,14 @@ def __init__(self): num_sliders=0, num_counters=2, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -531,7 +545,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { + VoronoiOtsu().name: VoronoiOtsu, Watershed().name: Watershed, ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, } From d6a3f430c0febb2a764843d280c30b17ef86c5e7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 043/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 5 +++-- docs/res/welcome.rst | 15 +++++++++------ napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 0654ad41..415a4f3d 100644 --- a/README.md +++ b/README.md @@ -129,8 +129,9 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). - +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). +Please refer to the documentation for full acknowledgements. ## Plugin base This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 6832e71e..d2f2c0f0 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -90,20 +90,23 @@ We also provide a model that was trained in-house on mesoSPIM nuclei data in col This plugin mainly uses the following libraries and software: -* `napari website`_ +* `napari`_ -* `PyTorch website`_ +* `PyTorch`_ -* `MONAI project website`_ (various models used here are credited `on their website`_) +* `MONAI project`_ (various models used here are credited `on their website`_) + +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ .. _Wyss Center: https://wysscenter.ch/ .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP -.. _napari website: https://napari.org/ -.. _PyTorch website: https://pytorch.org/ -.. _MONAI project website: https://monai.io/ +.. _napari: https://napari.org/ +.. _PyTorch: https://pytorch.org/ +.. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets +.. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype .. rubric:: References diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d2f8d787..136da3e1 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1053,9 +1053,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From acfa281d2374dc6e9508579dacd464289a9bef6b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 044/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 739b7aa3..ead0052c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 2c11b311e6ae53f18b3054f9bb2aea71fcd573cc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:11:22 +0100 Subject: [PATCH 045/577] Update setup.cfg --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index e12775ca..eb7b30a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ install_requires = itk tqdm nibabel + pyclesperanto-prototype scikit-image pillow tqdm From 5ea88a5ff09fe3c08fc7c88f4abbab9e3529740e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 046/577] isort --- napari_cellseg3d/code_models/model_instance_seg.py | 12 ++---------- napari_cellseg3d/code_plugins/plugin_convert.py | 9 ++------- .../code_plugins/plugin_model_inference.py | 8 ++++---- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 3 ++- 5 files changed, 15 insertions(+), 28 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 7fd33317..7a5f097b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,30 +1,22 @@ from __future__ import division from __future__ import print_function - from dataclasses import dataclass from typing import List - import numpy as np - import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget - from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -from skimage.filters import thresholding -from skimage.transform import resize - +from tifffile import imread # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes -from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f461b46f..6432d761 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,22 +1,17 @@ import warnings from pathlib import Path - import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QWidget from tifffile import imread from tifffile import imwrite import napari_cellseg3d.interface as ui -from napari_cellseg3d import config from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - clear_small_objects, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 4a7ab671..2420829e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,13 +10,13 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( INSTANCE_SEGMENTATION_METHOD_LIST, ) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 74cbf81d..e665d28c 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,6 +8,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -16,12 +21,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 136da3e1..a854905b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -12,7 +12,8 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt -# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From 9c5e4bdad24f70bb73da3c5627ffe8191d39c9c7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 047/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 napari_cellseg3d/_tests/pytest.ini diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini new file mode 100644 index 00000000..814cca2e --- /dev/null +++ b/napari_cellseg3d/_tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +qt_api=pyqt5 \ No newline at end of file From f92d2bbfaccbb1fd90355e93cb549a6aab9db3bb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 048/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/model_instance_seg.py | 84 +++++++++++-------- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 2 +- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 7a5f097b..57065971 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,11 +15,16 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area +# list of methods : +WATERSHED = "Watershed" +CONNECTED_COMP = "Connected Components" +VORONOI_OTSU = "Voronoi-Otsu" + class InstanceMethod: def __init__( @@ -28,6 +33,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, + widget_parent: QWidget = None ): """ Methods for instance segmentation @@ -36,6 +42,7 @@ def __init__( function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets """ self.name = name self.function = function @@ -47,7 +54,7 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -57,7 +64,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label=""), + ui.DoubleIncrementCounter(label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -351,15 +358,16 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -396,15 +404,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -426,15 +435,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -479,7 +489,6 @@ def __init__(self, parent=None): parent: parent widget """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) @@ -490,37 +499,38 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -537,7 +547,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6432d761..7a59dcf0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -376,7 +376,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.data_panel = self._build_io_panel() self.label_layer_loader.set_layer_type(napari.layers.Layer) - self.instance_widgets = InstanceWidgets() + self.instance_widgets = InstanceWidgets(parent=self) self.start_btn = ui.Button("Start", self._start) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 2420829e..5ad8fc3e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -191,7 +191,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## ################## # instance segmentation widgets - self.instance_widgets = InstanceWidgets(self) + self.instance_widgets = InstanceWidgets(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", func=self._toggle_display_instance From 462a0cc5db249ea2475d17c5f4e8886c7b67032e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 049/577] Fix inference --- .../code_models/model_instance_seg.py | 5 +- napari_cellseg3d/code_models/model_workers.py | 12 ++--- .../code_plugins/plugin_model_inference.py | 13 ++--- napari_cellseg3d/config.py | 6 ++- notebooks/assess_instance.ipynb | 50 +++++++++++++++++++ 5 files changed, 71 insertions(+), 15 deletions(-) create mode 100644 notebooks/assess_instance.ipynb diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 57065971..667b8bc3 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -492,8 +492,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -505,6 +507,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index a43077e4..7fd15c72 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -309,9 +309,7 @@ def log_parameters(self): instance_config = config.post_process_config.instance if instance_config.enabled: self.log( - f"Instance segmentation enabled, method : {instance_config.method}\n" - f"Probability threshold is {instance_config.threshold.threshold_value:.2f}\n" - f"Objects smaller than {instance_config.small_object_removal_threshold.threshold_value} pixels will be removed\n" + f"Instance segmentation enabled, method : {instance_config.method.name}\n" ) self.log("-" * 20) @@ -383,7 +381,7 @@ def load_folder(self): return inference_loader def load_layer(self): - self.log("Loading layer\n") + self.log("\nLoading layer\n") data = np.squeeze(self.config.layer) volume = np.array(data, dtype=np.int16) @@ -544,7 +542,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -597,8 +595,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 5ad8fc3e..a6a90eb4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -553,9 +553,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -723,13 +724,13 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method = self.worker_config.post_process_config.instance.method + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 ) # remove background - name = f"({number_cells} objects)_{method}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" viewer.add_labels(labels, name=name) @@ -743,7 +744,7 @@ def on_yield(self, result: InferenceResult): f"Number of instances : {stats.number_objects}" ) - csv_name = f"/{method}_seg_results_{image_id}_{utils.get_date_time()}.csv" + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" stats_df.to_csv( self.worker_config.results_path + csv_name, index=False, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index e665d28c..107af8e6 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -119,12 +119,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb new file mode 100644 index 00000000..40412282 --- /dev/null +++ b/notebooks/assess_instance.ipynb @@ -0,0 +1,50 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tifffile import imread\n", + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From de02fcc903ce6c3773270396d5f92fb9ef7a6e39 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 050/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 9 +- napari_cellseg3d/code_plugins/plugin_crop.py | 4 +- .../code_plugins/plugin_utilities.py | 7 +- .../dev_scripts/artefact_labeling.py | 421 ++++++++++++++++++ .../dev_scripts/correct_labels.py | 320 +++++++++++++ .../dev_scripts/evaluate_labels.py | 276 ++++++++++++ notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 7 files changed, 1420 insertions(+), 18 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/artefact_labeling.py create mode 100644 napari_cellseg3d/dev_scripts/correct_labels.py create mode 100644 napari_cellseg3d/dev_scripts/evaluate_labels.py diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 667b8bc3..a8bb240b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -37,12 +37,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -118,14 +120,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method + """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) @@ -191,6 +194,7 @@ def binary_watershed( thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 rem_seed_thres (int): threshold for small seeds removal. Default : 3 + """ semantic = np.squeeze(volume) seed_map = semantic > thres_seeding @@ -487,6 +491,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -538,10 +543,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 5f502978..9f4d80b6 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -177,8 +177,8 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 400]) - self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding) + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() # def _check_results_path(self, folder): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 9e66213f..1f0d598b 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -60,10 +60,10 @@ def _build(self): layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) - layout.setSizeConstraint(QLayout.SetFixedSize) + # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) - self.setMinimumHeight(1000) - self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + # self.setMinimumHeight(2000) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._update_visibility() def _create_utils_widgets(self, names): @@ -79,7 +79,6 @@ def _create_utils_widgets(self, names): raise RuntimeError( "One or several utility widgets are missing/erroneous" ) - # TODO how to auto-update list based on UTILITIES_WIDGETS ? def _update_visibility(self): widget_class = UTILITIES_WIDGETS[self.utils_choice.currentText()] diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py new file mode 100644 index 00000000..875ca9b6 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -0,0 +1,421 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path +import scipy.ndimage as ndimage +import os +import napari +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu + +""" +New code by Yves Paychere +Creates labels of artifacts in an image based on existing labels of neurons +""" + + +def map_labels(labels, artefacts): + """Map the artefacts labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + artefacts : ndarray + Label image with artefacts labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the artefact and the label value of the neurone associated or the neurons associated + new_labels: list + The labels of the artefacts that are not labelled in the neurons + """ + map_labels_existing = [] + new_labels = [] + + for i in np.unique(artefacts): + if i == 0: + continue + indexes = labels[artefacts == i] + # find the most common label in the indexes + unique, counts = np.unique(indexes, return_counts=True) + unique = np.flip(unique[np.argsort(counts)]) + counts = np.flip(counts[np.argsort(counts)]) + if unique[0] != 0: + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + elif ( + counts[0] < np.sum(counts) * 2 / 3.0 + ): # the artefact is connected to multiple neurons + total = 0 + ii = 1 + while total < np.size(indexes) / 3.0: + total = np.sum(counts[1 : ii + 1]) + ii += 1 + map_labels_existing.append(np.append([i], unique[1 : ii + 1])) + else: + new_labels.append(i) + + return map_labels_existing, new_labels + + +def make_labels( + path_image, + path_labels_out, + threshold_factor=1, + threshold_size=30, + label_value=1, + do_multi_label=True, + use_watershed=True, + augment_contrast_factor=2, +): + """Detect nucleus. using a binary watershed algorithm and otsu thresholding. + Parameters + ---------- + path_image : str + Path to image. + path_labels_out : str + Path of the output labelled image. + threshold_size : int, optional + Threshold for nucleus size, if the nucleus is smaller than this value it will be removed. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different nucleus will be labelled as a different value. + use_watershed : bool, optional + If True, use watershed algorithm to detect nucleus. + augment_contrast_factor : int, optional + Factor to augment the contrast of the image. + Returns + ------- + ndarray + Label image with nucleus labelled with 1 value per nucleus. + """ + + image = imread(path_image) + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + threshold_brightness = threshold_otsu(image) * threshold_factor + image_contrasted = np.where(image > threshold_brightness, image, 0) + + if use_watershed: + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = image_contrasted * augment_contrast_factor + image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) + labels = binary_watershed(image_contrasted, thres_small=threshold_size) + else: + labels = ndimage.label(image_contrasted)[0] + + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + + if not do_multi_label: + labels = np.where(labels > 0, label_value, 0) + + imwrite(path_labels_out, labels.astype(np.uint16)) + imwrite( + path_labels_out.replace(".tif", "_contrast.tif"), + image_contrasted.astype(np.float32), + ) + + +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): + """Select image by labels. + Parameters + ---------- + path_image : str + Path to image. + path_labels : str + Path to labels. + path_image_out : str + Path of the output image. + label_values : list + List of label values to select. + """ + image = imread(path_image) + labels = imread(path_labels) + image = np.where(np.isin(labels, label_values), image, 0) + imwrite(path_image_out, image.astype(np.float32)) + + +# select the smalles cube that contains all the none zero pixel of an 3d image +def get_bounding_box(img): + height = np.any(img, axis=(0, 1)) + rows = np.any(img, axis=(0, 2)) + cols = np.any(img, axis=(1, 2)) + + xmin, xmax = np.where(cols)[0][[0, -1]] + ymin, ymax = np.where(rows)[0][[0, -1]] + zmin, zmax = np.where(height)[0][[0, -1]] + return xmin, xmax, ymin, ymax, zmin, zmax + + +# crop the image +def crop_image(img): + xmin, xmax, ymin, ymax, zmin, zmax = get_bounding_box(img) + return img[xmin:xmax, ymin:ymax, zmin:zmax] + + +def crop_image_path(path_image, path_image_out): + """Crop image. + Parameters + ---------- + path_image : str + Path to image. + path_image_out : str + Path of the output image. + """ + image = imread(path_image) + image = crop_image(image) + imwrite(path_image_out, image.astype(np.float32)) + + +def make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, + label_value=2, + do_multi_label=False, + remove_true_labels=True, +): + """Detect pseudo nucleus. + Parameters + ---------- + image : ndarray + Image. + labels : ndarray + Label image. + threshold_artefact_brightness_percent : int, optional + Threshold for artefact brightness. + threshold_artefact_size_percent : int, optional + Threshold for artefact size, if the artefcact is smaller than this percentage of the neurons it will be removed. + contrast_power : int, optional + Power for contrast enhancement. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different artefact will be labelled as a different value. + remove_true_labels : bool, optional + If True, the true labels will be removed from the artefacts. + Returns + ------- + ndarray + Label image with pseudo nucleus labelled with 1 value per artefact. + """ + + neurons = np.array(labels > 0) + non_neurons = np.array(labels == 0) + + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + # calculate the percentile of the intensity of all the pixels that are labeled as neurons + # check if the neurons are not empty + if np.sum(neurons) > 0: + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + else: + # take the percentile of the non neurons if the neurons are empty + threshold = np.percentile(image[non_neurons], 90) + + # modify the contrast of the image accoring to the threshold with a tanh function and map the values to [0,1] + + image_contrasted = np.tanh((image - threshold) * contrast_power) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) + + artefacts = binary_watershed( + image_contrasted, thres_seeding=0.95, thres_small=15, thres_objects=0.4 + ) + + if remove_true_labels: + # evaluate where the artefacts are connected to the neurons + # map the artefacts label to the neurons label + map_labels_existing, new_labels = map_labels(labels, artefacts) + + # remove the artefacts that are connected to the neurons + for i in map_labels_existing: + artefacts[artefacts == i[0]] = 0 + # remove all the pixels of the neurons from the artefacts + artefacts = np.where(labels > 0, 0, artefacts) + + # remove the artefacts that are too small + # calculate the percentile of the size of the neurons + if np.sum(neurons) > 0: + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + else: + # find the size of each connected component + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + # remove the smallest connected components + neurone_size_percentile = np.percentile(sizes, 95) + + # select the artefacts that are bigger than the percentile + + artefacts = select_artefacts_by_size( + artefacts, min_size=neurone_size_percentile, is_labeled=True + ) + + # relabel with the label value if the artefacts are not multi label + if not do_multi_label: + artefacts = np.where(artefacts > 0, label_value, artefacts) + + return artefacts + + +def select_artefacts_by_size(artefacts, min_size, is_labeled=False): + """Select artefacts by size. + Parameters + ---------- + artefacts : ndarray + Label image with artefacts labelled as 1. + min_size : int, optional + Minimum size of artefacts to keep + is_labeled : bool, optional + If True, the artefacts are already labelled. + Returns + ------- + ndarray + Label image with artefacts labelled and small artefacts removed. + """ + if not is_labeled: + # find all the connected components in the artefacts image + labels = ndimage.label(artefacts)[0] + else: + labels = artefacts + + # remove the small components + labels_i, counts = np.unique(labels, return_counts=True) + labels_i = labels_i[counts > min_size] + labels_i = labels_i[labels_i > 0] + artefacts = np.where(np.isin(labels, labels_i), labels, 0) + return artefacts + + +def create_artefact_labels( + image_path, + labels_path, + output_path, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. + Parameters + ---------- + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. + output_path : str + Path to save the output label image file. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + image = imread(image_path) + labels = imread(labels_path) + + artefacts = make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power=contrast_power, + label_value=2, + do_multi_label=False, + ) + + neurons_artefacts_labels = np.where(labels > 0, 1, artefacts) + imwrite(output_path, neurons_artefacts_labels) + + +def visualize_images(paths): + """Visualize images. + Parameters + ---------- + paths : list + List of paths to images to visualize. + """ + viewer = napari.Viewer(ndisplay=3) + for path in paths: + viewer.add_image(imread(path), name=os.path.basename(path)) + # wait for the user to close the viewer + napari.run() + + +def create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1 for all images in a folder. The images created are stored in a folder artefact_neurons. + Parameters + ---------- + path : str + Path to folder with images in folder volumes and labels in folder lab_sem. The images are expected to have the same alphabetical order in both folders. + do_visualize : bool, optional + If True, the images will be visualized. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + # find all the images in the folder and create a list + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + # sort the list + path_labels.sort() + path_images.sort() + # create the output folder + os.makedirs(path + "/artefact_neurons", exist_ok=True) + # create the artefact labels + for i in range(len(path_images)): + print(path_labels[i]) + # consider that the images and the labels have names in the same alphabetical order + create_artefact_labels( + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power, + ) + if do_visualize: + visualize_images( + [ + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + ] + ) + + +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py new file mode 100644 index 00000000..f94327e2 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -0,0 +1,320 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings +from napari.qt.threading import thread_worker +from tqdm import tqdm +import threading +# import sys +# sys.path.append(str(Path(__file__) / "../../")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +""" +New code by Yves Paychère +Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold +""" + + +def relabel_non_unique_i(label, save_path, go_fast=False): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + label : np.array + the label image + save_path : str + the path to save the relabeld image + """ + value_label = 0 + new_labels = np.zeros_like(label) + map_labels_existing = [] + unique_label = np.unique(label) + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + i = unique_label[i_label] + if i == 0: + continue + if go_fast: + new_label, to_add = ndimage.label(label == i) + map_labels_existing.append( + [i, list(range(value_label + 1, value_label + to_add + 1))] + ) + + else: + # catch the warning of the watershed + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_label = binary_watershed(label == i) + unique = np.unique(new_label) + to_add = unique[-1] + map_labels_existing.append([i, unique[1:] + value_label]) + + new_label[new_label != 0] += value_label + new_labels += new_label + value_label += to_add + + imwrite(save_path, new_labels) + return map_labels_existing + + +def add_label(old_label, artefact, new_label_path, i_labels_to_add): + """add the label to the label image + Parameters + ---------- + old_label : np.array + the label image + artefact : np.array + the artefact image that contains some neurons + new_label_path : str + the path to save the new label image + """ + new_label = old_label.copy() + max_label = np.max(old_label) + for i, i_label in enumerate(i_labels_to_add): + new_label[artefact == i_label] = i + max_label + 1 + imwrite(new_label_path, new_label) + + +returns = [] + + +def ask_labels(unique_artefact): + global returns + returns = [] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + + if i_labels_to_add_tmp == [0]: + print("no label added") + returns = [[]] + print("close the napari window to continue") + return + + for i in i_labels_to_add_tmp: + if i == 0: + print("0 is not a valid label") + # delete the 0 + i_labels_to_add_tmp.remove(i) + # test if all index are negative + if all(i < 0 for i in i_labels_to_add_tmp): + print( + "all labels are negative-> will add all the labels except the one you gave" + ) + i_labels_to_add = list(unique_artefact) + for i in i_labels_to_add_tmp: + if np.abs(i) in i_labels_to_add: + i_labels_to_add.remove(np.abs(i)) + else: + print("the label", np.abs(i), "is not in the label image") + i_labels_to_add_tmp = i_labels_to_add + else: + # remove the negative index + for i in i_labels_to_add_tmp: + if i < 0: + i_labels_to_add_tmp.remove(i) + print( + "ignore the negative label", + i, + " since not all the labels are negative", + ) + if i not in unique_artefact: + print("the label", i, "is not in the label image") + i_labels_to_add_tmp.remove(i) + + returns = [i_labels_to_add_tmp] + print("close the napari window to continue") + + +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + image_path : str + the path to the image + label_path : str + the path to the label image + go_fast : bool, optional + if True, the relabeling will be faster but the labels can more frequently be merged, by default False + check_for_unicity : bool, optional + if True, the relabeling will check if the labels are unique, by default True + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + global returns + + label = imread(label_path) + initial_label_path = label_path + if check_for_unicity: + # check if the label are unique + new_label_path = label_path[:-4] + "_relabel_unique.tif" + map_labels_existing = relabel_non_unique_i( + label, new_label_path, go_fast=go_fast + ) + print( + "visualize the relabeld image in white the previous labels and in red the new labels" + ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + label_path = new_label_path + # detect artefact + print("detection of potential neurons (in progress)") + image = imread(image_path) + artefact = make_artefact_labels.make_artefact_labels( + image, + imread(label_path), + do_multi_label=True, + threshold_artefact_brightness_percent=30, + threshold_artefact_size_percent=0, + contrast_power=30, + ) + print("detection of potential neurons (done)") + # ask the user if the artefact are not neurons + i_labels_to_add = [] + loop = True + unique_artefact = list(np.unique(artefact)) + while loop: + # visualize the artefact and ask the user which label to add to the label image + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t.start() + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="potential neurons") + viewer.add_labels(imread(label_path), name="labels") + napari.run() + t.join() + i_labels_to_add_tmp = returns[0] + # check if the selected labels are neurones + for i in i_labels_to_add: + if i not in i_labels_to_add_tmp: + i_labels_to_add_tmp.append(i) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + print("these labels will be added") + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if revert != "y": + i_labels_to_add = i_labels_to_add_tmp + for i in i_labels_to_add: + if i in unique_artefact: + unique_artefact.remove(i) + loop = input("Do you want to add more labels? (y/n)") == "y" + # add the label to the label image + new_label_path = initial_label_path[:-4] + "_new_label.tif" + print("the new label will be saved in", new_label_path) + add_label(imread(label_path), artefact, new_label_path, i_labels_to_add) + # store the artefact remaining + new_artefact_path = initial_label_path[:-4] + "_artefact.tif" + artefact = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + imwrite(new_artefact_path, artefact) + + +def modify_viewer(old_label, new_label, args): + """modify the viewer to show the relabeling + Parameters + ---------- + old_label : napari.layers.Labels + the layer of the old label + new_label : napari.layers.Labels + the layer of the new label + args : list + the first element is the old label and the second element is the new label + """ + if args == "hide new label": + new_label.visible = False + elif args == "show new label": + new_label.visible = True + else: + old_label.selected_label = args[0] + if not np.isnan(args[1]): + new_label.selected_label = args[1] + + +@thread_worker +def to_show(map_labels_existing, delay=0.5): + """modify the viewer to show the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the of the map between the old label and the new label + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + time.sleep(2) + for i in map_labels_existing: + yield "hide new label" + if len(i[1]): + yield [i[0], i[1][0]] + else: + yield [i[0], np.nan] + time.sleep(delay) + yield "show new label" + for j in i[1]: + yield [i[0], j] + time.sleep(delay) + + +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): + """Builds a widget that can control a function in another thread.""" + + worker = to_show(map_labels_existing, delay) + worker.start() + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + + +def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): + """visualize the map of the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the relabeling + """ + label = imread(label_path) + relabel = imread(relabel_path) + + viewer = napari.Viewer(ndisplay=3) + + old_label = viewer.add_labels(label, num_colors=3) + new_label = viewer.add_labels(relabel, num_colors=3) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + + # viewer.dims.ndisplay = 3 + viewer.camera.angles = (180, 3, 50) + viewer.camera.zoom = 1 + + old_label.show_selected_label = True + new_label.show_selected_label = True + + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + napari.run() + + +def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + folder_path : str + the path to the folder containing the label images + end_of_new_name : str + thename to add at the end of the relabled image + """ + for file in Path.iterdir(folder_path): + if file.suffix == ".tif": + label = imread(str(Path(folder_path / file))) + relabel_non_unique_i( + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + ) + + +if __name__ == "__main__": + + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py new file mode 100644 index 00000000..857bcd19 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -0,0 +1,276 @@ +import numpy as np +import pandas as pd +from tqdm import tqdm +import napari + +from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels + + +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): + """Evaluate the model performance. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + do_print : bool + If True, print the results. + Returns + ------- + neuron_found : float + The number of neurons found by the model + neuron_fused: float + The number of neurons fused by the model + neuron_not_found: float + The number of neurons not found by the model + neuron_artefact: float + The number of artefact that the model wrongly labelled as neurons + mean_true_positive_ratio_model: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the model's label) + mean_ratio_pixel_found: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_ratio_pixel_found_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_true_positive_ratio_model_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels in any fused neurons of this model's label)/(total number of pixels of the model's label) + mean_ratio_false_pixel_artefact: float + The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) + """ + log.debug("Mapping labels...") + map_labels_existing, map_fused_neurons, new_labels = map_labels( + labels, model_labels + ) + + # calculate the number of neurons individually found + neurons_found = len(map_labels_existing) + # calculate the number of neurons fused + neurons_fused = len(map_fused_neurons) + # calculate the number of neurons not found + log.debug("Calculating the number of neurons not found...") + neurons_found_labels = np.unique( + [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] + ) + unique_labels = np.unique(labels) + neurons_not_found = len(unique_labels) - 1 - len(neurons_found_labels) + # artefacts found + artefacts_found = len(new_labels) + if len(map_labels_existing) > 0: + # calculate the mean true positive ratio of the model + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + # calculate the mean ratio of the neurons pixels correctly labelled + mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) + else: + mean_true_positive_ratio_model = np.nan + mean_ratio_pixel_found = np.nan + + if len(map_fused_neurons) > 0: + # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + # calculate the mean true positive ratio of the model for the fused neurons + mean_true_positive_ratio_model_fused = np.mean( + [i[3] for i in map_fused_neurons] + ) + else: + mean_ratio_pixel_found_fused = np.nan + mean_true_positive_ratio_model_fused = np.nan + + # calculate the mean false positive ratio of each artefact + if len(new_labels) > 0: + mean_ratio_false_pixel_artefact = np.mean([i[1] for i in new_labels]) + else: + mean_ratio_false_pixel_artefact = np.nan + + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( + "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, + ) + print( + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, + ) + print( + "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, + ) + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + ) + if visualize: + viewer = napari.Viewer() + viewer.add_labels(labels, name="ground truth") + viewer.add_labels(model_labels, name="model's labels") + found_model = np.where( + np.isin(model_labels, [i[0] for i in map_labels_existing]), + model_labels, + 0, + ) + viewer.add_labels(found_model, name="model's labels found") + found_label = np.where( + np.isin(labels, [i[1] for i in map_labels_existing]), labels, 0 + ) + viewer.add_labels(found_label, name="ground truth found") + neurones_not_found_labels = np.where( + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + ) + neurones_not_found_labels = neurones_not_found_labels[ + neurones_not_found_labels != 0 + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + viewer.add_labels(not_found, name="ground truth not found") + artefacts_found = np.where( + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + ) + viewer.add_labels(artefacts_found, name="model's labels artefacts") + fused_model = np.where( + np.isin(model_labels, [i[0] for i in map_fused_neurons]), + model_labels, + 0, + ) + viewer.add_labels(fused_model, name="model's labels fused") + fused_label = np.where( + np.isin(labels, [i[1] for i in map_fused_neurons]), labels, 0 + ) + viewer.add_labels(fused_label, name="ground truth fused") + napari.run() + + return ( + neurons_found, + neurons_fused, + neurons_not_found, + artefacts_found, + mean_true_positive_ratio_model, + mean_ratio_pixel_found, + mean_ratio_pixel_found_fused, + mean_true_positive_ratio_model_fused, + mean_ratio_false_pixel_artefact, + ) + + +def save_as_csv(results, path): + """ + Save the results as a csv file + + Parameters + ---------- + results: list + The results of the evaluation + path: str + The path to save the csv file + """ + print(np.array(results).shape) + df = pd.DataFrame( + [results], + columns=[ + "neurons_found", + "neurons_fused", + "neurons_not_found", + "artefacts_found", + "mean_true_positive_ratio_model", + "mean_ratio_pixel_found", + "mean_ratio_pixel_found_fused", + "mean_true_positive_ratio_model_fused", + "mean_ratio_false_pixel_artefact", + ], + ) + df.to_csv(path, index=False) + + +# if __name__ == "__main__": +# """ +# # Example of how to use the functions in this module. +# a = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# b = np.array([[5, 5, 0, 0], [5, 5, 2, 0], [0, 2, 2, 0], [0, 0, 2, 0]]) +# evaluate_model_performance(a, b) +# +# c = np.array([[2, 2, 0, 0], [2, 2, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# d = np.array([[4, 0, 4, 0], [4, 4, 4, 0], [0, 4, 4, 0], [0, 0, 4, 0]]) +# +# evaluate_model_performance(c, d) +# +# from tifffile import imread +# labels=imread("dataset/visual_tif/labels/testing_im_new_label.tif") +# labels_model=imread("dataset/visual_tif/artefact_neurones/basic_model.tif") +# evaluate_model_performance(labels, labels_model,visualize=True) +# """ +# from tifffile import imread +# +# labels = imread("dataset_clean/VALIDATION/validation_labels.tif") +# try: +# labels_model = imread("results/watershed_based_model/instance_labels.tif") +# except: +# raise Exception( +# "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" +# ) +# +# evaluate_model_performance(labels, labels_model, visualize=True) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From d7ffb8bc8abf6c7aaf864820bd9542cdb2e9c22b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 051/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From ed46c5a907f5f2f69cc53e7f87ab8e719b17f8c9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 052/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- napari_cellseg3d/utils.py | 2 +- notebooks/assess_instance.ipynb | 553 ++++++++---------- requirements.txt | 4 +- setup.cfg | 2 +- 8 files changed, 569 insertions(+), 354 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index a8bb240b..77e5c981 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -136,7 +136,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7897fdf3..7f4f3dc1 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -134,7 +134,7 @@ def resize(image, zoom_factors): mode="nearest-exact", padding_mode="empty", )(np.expand_dims(image, axis=0)) - return isotropic_image[0] + return isotropic_image[0].numpy() def align_array_sizes(array_shape, target_shape): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/requirements.txt b/requirements.txt index ead0052c..834a225e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ black coverage isort +itk pytest pytest-qt sphinx @@ -18,6 +19,7 @@ matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,scikit-image,itk,einops]>=0.9.0 +monai[nibabel,einops]>=1.0.1 pillow +scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index eb7b30a2..d8adc6ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=0.9.0 + monai[nibabel,einops]>=1.0.1 itk tqdm nibabel From 2fa7e06c8c22e928b33225fea103ee24b3aeb8ec Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 053/577] black --- .../code_models/model_instance_seg.py | 21 ++++++++---- napari_cellseg3d/code_models/model_workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 77e5c981..f3a04059 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -33,7 +33,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -56,7 +56,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -365,13 +372,13 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -411,13 +418,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -448,7 +455,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7fd15c72..636f7acd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -542,7 +542,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index a6a90eb4..c9b59357 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,7 +555,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -724,7 +726,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 107af8e6..6e9cc89e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -119,11 +119,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From d914dfe978af63c870109277c1abda16cf19857f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 054/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From e30bf2fafff95df347056aaed36280abcf347e43 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 055/577] Added pre-commit hooks --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 834a225e..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 52c2bbc7694a03b841127fa0bad2bc9d406e9303 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 056/577] Enfore pre-commit style --- .gitignore | 1 + .../_tests/test_plugin_inference.py | 2 - .../code_models/model_instance_seg.py | 8 +- .../code_plugins/plugin_model_inference.py | 3 - .../code_plugins/plugin_utilities.py | 1 - napari_cellseg3d/config.py | 3 - .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- napari_cellseg3d/utils.py | 10 +-- notebooks/assess_instance.ipynb | 79 +++++++++++++------ notebooks/csv_cell_plot.ipynb | 2 - 12 files changed, 78 insertions(+), 56 deletions(-) diff --git a/.gitignore b/.gitignore index d08ff9f2..e86beea4 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png *.prof + diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 584ffd3b..e15958e6 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,8 +7,6 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST - - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index f3a04059..f83bfd4d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function from dataclasses import dataclass from typing import List import numpy as np @@ -10,6 +8,7 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -531,14 +530,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index c9b59357..483679ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,9 +10,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 1f0d598b..6e3b9981 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,6 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QLayout from qtpy.QtWidgets import QSizePolicy from qtpy.QtWidgets import QVBoxLayout from qtpy.QtWidgets import QWidget diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6e9cc89e..3d5f2a1a 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,10 +8,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7f4f3dc1..11d369a7 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -36,9 +36,7 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] @@ -413,17 +411,17 @@ def parse_default_path(possible_paths): def get_date_time(): """Get date and time in the following format : year_month_day_hour_minute_second""" - return "{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%Y_%m_%d_%H_%M_%S}" def get_time(): """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid)""" - return "{:%H:%M:%S}".format(datetime.now()) + return f"{datetime.now():%H:%M:%S}" def get_time_filepath(): """Get time in the following format : hour_minute_second. Compatible with saving""" - return "{:%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%H_%M_%S}" def load_images(dir_or_path, filetype="", as_folder: bool = False): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, diff --git a/notebooks/csv_cell_plot.ipynb b/notebooks/csv_cell_plot.ipynb index 8b14fb8d..e00a9f1c 100644 --- a/notebooks/csv_cell_plot.ipynb +++ b/notebooks/csv_cell_plot.ipynb @@ -58,7 +58,6 @@ "outputs": [], "source": [ "def plot_data(data_path, x_inv=False, y_inv=False, z_inv=False):\n", - "\n", " data = pd.read_csv(data_path, index_col=False)\n", "\n", " x = data[\"Centroid x\"]\n", @@ -185,7 +184,6 @@ "outputs": [], "source": [ "def plotly_cells_stats(data):\n", - "\n", " init_notebook_mode() # initiate notebook for offline plot\n", "\n", " x = data[\"Centroid x\"]\n", From 43a5bdfb298cf119b0c5822e80fc8777c886e922 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 057/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e86beea4..755de742 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,4 @@ notebooks/full_plot.html *.png *.prof + From 9c4c6eb427333ebe44671075c939d47683acb157 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 058/577] Version bump --- napari_cellseg3d/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..736c7f72 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1,2 @@ __version__ = "0.0.2rc6" + From c8c9712176e97c4178d4231db397a60cfcb45b89 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 059/577] Updated project files --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4ecb1b86..ec6cbd8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "nibabel", "scikit-image", "pillow", + "pyclesperanto-prototype", "tqdm", "matplotlib", "vispy>=0.9.6", @@ -59,6 +60,7 @@ dev = [ "black", "ruff", "pre-commit", + ] docs = [ "sphinx", @@ -73,3 +75,4 @@ test = [ "tox", "twine", ] + From a5257eedd6f4121744d4d72febaaa2c518aee1cb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 060/577] Fixed missing parent error --- napari_cellseg3d/code_models/model_instance_seg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index f83bfd4d..ccdb5b18 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -448,7 +448,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From 6d2c79ea6ea24e0e24ffd06f565ef919c6c1389e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 061/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 35 ++++++++++++------- .../code_plugins/plugin_model_inference.py | 1 + 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index ccdb5b18..979f861c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -138,6 +138,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -146,7 +149,7 @@ def voronoi_otsu( def binary_connected( - volume, + volume: np.array, thres=0.5, thres_small=3, ): @@ -158,8 +161,12 @@ def binary_connected( thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 """ + logger.debug( + f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" + ) + # if len(volume.shape) > 3: semantic = np.squeeze(volume) - foreground = semantic > thres # int(255 * thres) + foreground = np.where(semantic > thres, volume, 0) # int(255 * thres) segm = label(foreground) segm = remove_small_objects(segm, thres_small) @@ -202,6 +209,10 @@ def binary_watershed( rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ + logger.debug( + f"Running watershed segmentation with thres_objects={thres_objects}, thres_seeding={thres_seeding}," + f" thres_small={thres_small} and rem_seed_thres={rem_seed_thres}" + ) semantic = np.squeeze(volume) seed_map = semantic > thres_seeding foreground = semantic > thres_objects @@ -407,8 +418,8 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -441,7 +452,7 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -501,7 +512,7 @@ def __init__(self, parent=None): """ super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -520,7 +531,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -530,8 +541,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -555,9 +568,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 483679ef..fb6fb71c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -184,6 +184,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## From 4271da2c571ccbab2180693f27f5dea4e1848ec5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 062/577] Removing dask-image --- .gitignore | 1 + napari_cellseg3d/dev_scripts/convert.py | 5 +- napari_cellseg3d/dev_scripts/view_brain.py | 2 +- napari_cellseg3d/dev_scripts/view_sample.py | 2 +- napari_cellseg3d/utils.py | 113 ++++++++++---------- notebooks/full_plot.ipynb | 3 +- setup.cfg | 1 + 7 files changed, 64 insertions(+), 63 deletions(-) diff --git a/.gitignore b/.gitignore index 755de742..f8547d92 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,4 @@ notebooks/full_plot.html *.prof +*.prof diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 129c16be..479a07dd 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,7 +2,7 @@ import os import numpy as np -from dask_image.imread import imread +from tifffile import imread from tifffile import imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" @@ -19,8 +19,7 @@ filenames.append(os.path.basename(filename)) # print(os.path.basename(filename)) for file in paths: - img = imread(file) - image = img.compute() + image = imread(file) image[image >= 1] = 1 image = image.astype(np.uint16) diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py index e5879638..145d4e45 100644 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ b/napari_cellseg3d/dev_scripts/view_brain.py @@ -1,5 +1,5 @@ import napari -from dask_image.imread import imread +from tifffile import imread y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py index 329944ac..8e87f85c 100644 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ b/napari_cellseg3d/dev_scripts/view_sample.py @@ -1,5 +1,5 @@ import napari -from dask_image.imread import imread +from tifffile import imread # Visual x = imread( diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 11d369a7..4b04d536 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -4,9 +4,8 @@ from pathlib import Path import numpy as np -from dask_image.imread import imread as dask_imread -from pandas import DataFrame -from pandas import Series + +# from dask import delayed from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -17,7 +16,6 @@ # LOGGER.setLevel(logging.DEBUG) LOGGER.setLevel(logging.INFO) ############### - """ utils.py ==================================== @@ -278,51 +276,51 @@ def annotation_to_input(label_ermito): return anno -def check_csv(project_path, ext): - if not Path(Path(project_path) / Path(project_path).name).is_file(): - cols = [ - "project", - "type", - "ext", - "z", - "y", - "x", - "z_size", - "y_size", - "x_size", - "created_date", - "update_date", - "path", - "notes", - ] - df = DataFrame(index=[], columns=cols) - filename_pattern_original = Path(project_path) / Path( - f"dataset/Original_size/Original/*{ext}" - ) - images_original = dask_imread(filename_pattern_original) - z, y, x = images_original.shape - record = Series( - [ - Path(project_path).name, - "dataset", - ".tif", - 0, - 0, - 0, - z, - y, - x, - datetime.datetime.now(), - "", - Path(project_path) / Path("dataset/Original_size/Original"), - "", - ], - index=df.columns, - ) - df = df.append(record, ignore_index=True) - df.to_csv(Path(project_path) / Path(project_path).name) - else: - pass +# def check_csv(project_path, ext): +# if not Path(Path(project_path) / Path(project_path).name).is_file(): +# cols = [ +# "project", +# "type", +# "ext", +# "z", +# "y", +# "x", +# "z_size", +# "y_size", +# "x_size", +# "created_date", +# "update_date", +# "path", +# "notes", +# ] +# df = DataFrame(index=[], columns=cols) +# filename_pattern_original = Path(project_path) / Path( +# f"dataset/Original_size/Original/*{ext}" +# ) +# images_original = dask_imread(filename_pattern_original) +# z, y, x = images_original.shape +# record = Series( +# [ +# Path(project_path).name, +# "dataset", +# ".tif", +# 0, +# 0, +# 0, +# z, +# y, +# x, +# datetime.datetime.now(), +# "", +# Path(project_path) / Path("dataset/Original_size/Original"), +# "", +# ], +# index=df.columns, +# ) +# df = df.append(record, ignore_index=True) +# df.to_csv(Path(project_path) / Path(project_path).name) +# else: +# pass # def check_annotations_dir(project_path): @@ -457,7 +455,10 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): raise ValueError("If loading as a folder, filetype must be specified") if as_folder: - images_original = dask_imread(filename_pattern_original) + raise NotImplementedError( + "Loading as folder not implemented yet. Use napari to load as folder" + ) + # images_original = dask_imread(filename_pattern_original) else: images_original = tfl_imread( filename_pattern_original @@ -478,12 +479,12 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): # return base_label -def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): - images_label = load_images(mod_mask_dir, filetype, as_folder) - if as_folder: - images_label = images_label.compute() - base_label = images_label - return base_label +# def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): +# images_label = load_images(mod_mask_dir, filetype, as_folder) +# if as_folder: +# images_label = images_label.compute() +# base_label = images_label +# return base_label def save_stack(images, out_path, filetype=".png", check_warnings=False): diff --git a/notebooks/full_plot.ipynb b/notebooks/full_plot.ipynb index 857384d4..f804598e 100644 --- a/notebooks/full_plot.ipynb +++ b/notebooks/full_plot.ipynb @@ -10,8 +10,7 @@ "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", - "from PIL import Image\n", - "from dask_image.imread import imread" + "from tifffile import imread" ] }, { diff --git a/setup.cfg b/setup.cfg index d8adc6ae..2420dd1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ package_dir = # add your package requirements here # the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +# FIXME remove dask install_requires = numpy napari[all]>=0.4.14 From 8b44b398a344d3aa4c5eead8af8e036a82102004 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 063/577] Fixed erroneous dtype conversion --- napari_cellseg3d/code_models/model_instance_seg.py | 13 +++++++++++-- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 979f861c..436135a1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -137,12 +137,12 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -489,6 +489,15 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 7a59dcf0..c1493fa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -150,7 +150,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) save_layer( @@ -168,7 +168,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - utils.resize(np.array(imread(file), dtype=np.int16), zoom) + utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] save_folder( @@ -249,7 +249,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( @@ -330,7 +330,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) semantic = to_semantic(data) save_layer( @@ -414,7 +414,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) instance = self.instance_widgets.run_method(data) save_layer( @@ -509,7 +509,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( From ca5349dd6ee825be905249da4808d0d0b6359dd9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 064/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index a5b6fd94..8fd297c2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,3 +1,7 @@ +from pathlib import Path +from tifffile import imread +import numpy as np + from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, @@ -8,9 +12,15 @@ def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(im_path) + view.add_image(image) + view.add_labels(image.astype(np.uint8)) + view.window.add_dock_widget(widget) for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + widget.utils_widgets[i]._start() From a485d22ad0c373b766a15f42b843c172908aade7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:38:13 +0200 Subject: [PATCH 065/577] Temporary test action patch --- .github/workflows/test_and_deploy.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5dcd11ae..ea0a1e46 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,12 +8,14 @@ on: branches: - main - npe2 + - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 + - cy/voronoi-otsu workflow_dispatch: jobs: From 852c52b281c380ac7a41e49ba81352061366805e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:50:16 +0200 Subject: [PATCH 066/577] Update plugin_convert.py --- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index c1493fa4..6908b7aa 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -34,7 +34,7 @@ def save_folder(results_path, folder_name, images, image_paths): image_paths: list of filenames of images """ results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False) + results_folder.mkdir(exist_ok=False, parents=True) for file, image in zip(image_paths, images): path = results_folder / Path(file).name @@ -143,7 +143,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -242,7 +242,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -324,7 +324,7 @@ def _build(self): ) def _start(self): - Path(self.results_path).mkdir(exist_ok=True) + Path(self.results_path).mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -408,7 +408,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -502,7 +502,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.binarize_counter.value() if self.layer_choice: From eebc2f55d551e7b0d2a5da91f666732c2fbaafae Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 067/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 3409f43c..162e2c20 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + pocl ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From ea70143cf7b827fabb85e67e2763ac771c31bac9 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 068/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 162e2c20..5eb6558d 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From c00229c2c5a0d988879addaefaff1eeb22f034ce Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 069/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 5eb6558d..ac7dafcf 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 8c9c748c3ab37d90e0ae82c35f9ed5cdd6bc2c0d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 070/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 8fd297c2..b2d9de52 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -23,4 +23,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index ac7dafcf..030d7437 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 1d847a7c14ebbb43910e65647000459caab791ab Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 071/577] Relabeling tests --- .gitignore | 6 +- napari_cellseg3d/_tests/res/test_labels.tif | Bin 0 -> 2026 bytes .../_tests/test_labels_correction.py | 51 ++++++++++ .../dev_scripts/artefact_labeling.py | 93 ++++++++---------- .../dev_scripts/correct_labels.py | 75 +++++++++----- 5 files changed, 151 insertions(+), 74 deletions(-) create mode 100644 napari_cellseg3d/_tests/res/test_labels.tif create mode 100644 napari_cellseg3d/_tests/test_labels_correction.py diff --git a/.gitignore b/.gitignore index f8547d92..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -106,5 +106,7 @@ notebooks/full_plot.html *.png *.prof - -*.prof +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/res/test_labels.tif b/napari_cellseg3d/_tests/res/test_labels.tif new file mode 100644 index 0000000000000000000000000000000000000000..0486d789ea658acc32616b40869833accf8d01d7 GIT binary patch literal 2026 zcmcK5yJ}QX6b9gPW+oRk7ZaUC6EDMfA4AYa#WzSNSc-*3f&q(wHX^pxK7x-TK7$V+ zh-hgcUQztNum?}HQam9)Yn^rd*V>ys8yll)x~i)As;YZc9c?nG8+xbi?%D^jcZTs? z;A_lNkw1zQI~mBsOB}G_#)iYWBJ~`{jpd=(^f$j8o7U4Q+J#zTzQ_J1_z;*uteC68 z$zUP)5+6=ue*NfXR{vChyE>l&{pDQ@Rs-|ldNz=U59v(k#{#wVv17fKBh6$ldYFA! zj2TD_w8=&Bcb7Z$2DI=OnVnep55ES3J|ZCRZ7^|q`;Z@w+f_hcAf uO7Fq{VSFjiRU3?7w#N8*ON^i72d14J-{`ip<7-oGF@Dt&<6PlCcKj3h Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 072/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4202f04e..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: rev: 5.12.0 hooks: - id: isort - args: ["--profile", "black", --line-length=72] + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From aba48679a0795a0ab815c79b1c73cd643f7a3feb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 073/577] Run full suite of pre-commit hooks --- README.md | 2 +- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 3 +-- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- 9 files changed, 29 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 415a4f3d..7fa77422 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index b2d9de52..8dcd3c7e 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 436135a1..6d0dc13d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -14,8 +14,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -489,7 +489,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log From 2bed19dcc0f241efa1a15604d8487393161d0e47 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:08:38 +0200 Subject: [PATCH 074/577] Enforce style --- napari_cellseg3d/__init__.py | 1 - napari_cellseg3d/_tests/test_plugin_inference.py | 1 + napari_cellseg3d/_tests/test_plugin_utils.py | 4 +--- napari_cellseg3d/code_models/model_instance_seg.py | 8 +++++--- napari_cellseg3d/code_models/models/unet/model.py | 4 +--- napari_cellseg3d/code_plugins/plugin_convert.py | 2 ++ napari_cellseg3d/code_plugins/plugin_crop.py | 4 +--- napari_cellseg3d/code_plugins/plugin_review.py | 4 +--- napari_cellseg3d/code_plugins/plugin_utilities.py | 4 +--- napari_cellseg3d/config.py | 1 - napari_cellseg3d/interface.py | 5 +---- pyproject.toml | 1 - 12 files changed, 14 insertions(+), 25 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 736c7f72..11e8de0e 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1,2 +1 @@ __version__ = "0.0.2rc6" - diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e15958e6..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,6 +7,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 8dcd3c7e..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -4,9 +4,7 @@ from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, -) +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6d0dc13d..cc362eac 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import List + import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget @@ -9,14 +10,15 @@ from skimage.segmentation import watershed from tifffile import imread -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index c5cc78d3..6cc76be6 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -6,9 +6,7 @@ from napari_cellseg3d.code_models.models.unet.buildingblocks import ( create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( - DoubleConv, -) +from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6908b7aa..ed1a43df 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,5 +1,6 @@ import warnings from pathlib import Path + import napari import numpy as np from qtpy.QtWidgets import QSizePolicy @@ -354,6 +355,7 @@ def _start(self): self.images_filepaths, ) + class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 9f4d80b6..07885236 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -11,9 +11,7 @@ # local from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage DEFAULT_CROP_SIZE = 64 logger = utils.LOGGER diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 0044c8e2..a803dfd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -19,9 +19,7 @@ from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager warnings.formatwarning = utils.format_Warning diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 6e3b9981..c962717e 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -9,9 +9,7 @@ # local import napari_cellseg3d.interface as ui from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import ( - RemoveSmallUtils, -) +from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3d5f2a1a..ab3dba39 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -10,7 +10,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index a854905b..bb2a1efb 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,11 +7,8 @@ import napari # Qt -from qtpy import QtCore - # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt +from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt from qtpy.QtCore import QUrl diff --git a/pyproject.toml b/pyproject.toml index ec6cbd8c..2814960b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,4 +75,3 @@ test = [ "tox", "twine", ] - From 1f51af81c34264384ceafe1a8f6045fe012f9763 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 24 Mar 2023 17:08:44 +0100 Subject: [PATCH 075/577] Model class refactor --- docs/res/guides/custom_model_template.rst | 24 -- .../_tests/test_weight_download.py | 4 +- napari_cellseg3d/code_models/model_workers.py | 364 ++++++++++-------- .../code_models/models/model_SegResNet.py | 48 ++- .../code_models/models/model_SwinUNetR.py | 36 +- .../code_models/models/model_TRAILMAP.py | 39 +- .../code_models/models/model_TRAILMAP_MS.py | 27 +- .../code_models/models/model_VNet.py | 56 +-- .../code_models/models/model_test.py | 24 +- .../code_plugins/plugin_model_inference.py | 143 +++---- .../code_plugins/plugin_model_training.py | 4 +- .../code_plugins/plugin_review.py | 2 +- napari_cellseg3d/config.py | 65 +++- napari_cellseg3d/interface.py | 18 +- napari_cellseg3d/utils.py | 18 +- notebooks/assess_instance.ipynb | 121 +++--- requirements.txt | 6 +- setup.cfg | 2 +- 18 files changed, 562 insertions(+), 439 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index afbcd98a..9bad49b0 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -10,28 +10,4 @@ To add a custom model, you will need a **.py** file with the following structure :: - def get_net(): - return ModelClass # should return the class of the model, - # for example SegResNet or UNET - - def get_weights_file(): - return "weights_file.pth" # name of the weights file for the model, - # which should be in *napari_cellseg3d/models/pretrained* - - - def get_output(model, input): - out = model(input) # should return the model's output as [C, N, D,H,W] - # (C: channel, N, batch size, D,H,W : depth, height, width) - return out - - - def get_validation(model, val_inputs): - val_outputs = model(val_inputs) # should return the proper type for validation - # with sliding_window_inference from MONAI - return val_outputs - - - def ModelClass(x1,x2...): - # your Pytorch model here... - return results # should return as [C, N, D,H,W] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index b8f0d748..51189e4b 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR from napari_cellseg3d.code_models.model_workers import WeightsDownloader @@ -6,6 +6,6 @@ def test_weight_download(): downloader = WeightsDownloader() downloader.download_weights("test", "test.pth") - result_path = WEIGHTS_DIR / "test.pth" + result_path = PRETRAINED_WEIGHTS_DIR / "test.pth" assert result_path.is_file() diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 636f7acd..25a9da19 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -2,8 +2,7 @@ from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List -from typing import Optional +import typing as t import numpy as np import torch @@ -39,6 +38,8 @@ # threads from napari.qt.threading import GeneratorWorker + +# from napari.qt.threading import thread_worker from napari.qt.threading import WorkerBaseSignals # Qt @@ -65,14 +66,16 @@ # https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ # https://napari-staging-site.github.io/guides/stable/threading.html -WEIGHTS_DIR = Path(__file__).parent.resolve() / Path("models/pretrained") -logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {WEIGHTS_DIR}") +PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( + "models/pretrained" +) +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") class WeightsDownloader: """A utility class the downloads the weights of a model when needed.""" - def __init__(self, log_widget: Optional[ui.Log] = None): + def __init__(self, log_widget: t.Optional[ui.Log] = None): """ Creates a WeightsDownloader, optionally with a log widget to display the progress. @@ -94,11 +97,11 @@ def download_weights(self, model_name: str, model_weights_filename: str): import tarfile import urllib.request - def show_progress(count, block_size, total_size): + def show_progress(_, block_size, __): # count, block_size, total_size pbar.update(block_size) logger.info("*" * 20) - pretrained_folder_path = WEIGHTS_DIR + pretrained_folder_path = PRETRAINED_WEIGHTS_DIR json_path = pretrained_folder_path / Path("pretrained_model_urls.json") check_path = pretrained_folder_path / Path(model_weights_filename) @@ -168,12 +171,17 @@ def safe_extract( class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" # TODO link ? + Separate from Worker instances as indicated `here`_ + + .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + """ # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" warn_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" + error_signal = Signal(Exception, str) + """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -204,33 +212,24 @@ def __init__( ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. - Args: - * config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * device: cuda or cpu device to use for torch - - * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance - - * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom - - * results_path: the path to save the results to - - * filetype: the file extension to use when saving, - - * transforms: a dict containing transforms to perform at various times. + The config contains the following attributes: + * device: cuda or cpu device to use for torch + * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance + * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom + * results_path: the path to save the results to + * filetype: the file extension to use when saving, + * transforms: a dict containing transforms to perform at various times. + * instance: a dict containing parameters regarding instance segmentation + * use_window: use window inference with specific size or whole image + * window_infer_size: size of window if use_window is True + * keep_on_cpu: keep images on CPU or no + * stats_csv: compute stats on cells and save them to a csv file + * images_filepaths: the paths to the images of the dataset + * layer: the layer to run inference on - * instance: a dict containing parameters regarding instance segmentation - - * use_window: use window inference with specific size or whole image - - * window_infer_size: size of window if use_window is True - - * keep_on_cpu: keep images on CPU or no - - * stats_csv: compute stats on cells and save them to a csv file - - * images_filepaths: the paths to the images of the dataset + Args: + * worker_config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * layer: the layer to run inference on Note: See :py:func:`~self.inference` """ @@ -238,6 +237,7 @@ def __init__( self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self.config = worker_config @@ -270,6 +270,21 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Raises an error in main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + + self.log_signal.emit("!" * 20) + self.log_signal.emit("Error occured") + # self.log_signal.emit(msg) + # self.log_signal.emit(str(exception)) + + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + yield exception + # self.quit() + def log_parameters(self): config = self.config @@ -398,7 +413,7 @@ def load_layer(self): ) # for anisotropy to be monai-like, i.e. zyx # FIXME rotation not always correct dims_check = volume.shape - # self.log("\nChecking dimensions...") + self.log("Checking dimensions...") pad = utils.get_padding_dim(dims_check) # logger.debug(volume.shape) @@ -449,55 +464,61 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): - return post_process_transforms( - self.config.model_info.get_model().get_output(model, inputs) - ) - if self.config.keep_on_cpu: dataset_device = "cpu" else: dataset_device = self.config.device - window_size = self.config.sliding_window_config.window_size - window_overlap = self.config.sliding_window_config.window_overlap - - # FIXME - # import sys - - # old_stdout = sys.stdout - # old_stderr = sys.stderr - - # sys.stdout = self.downloader.log_widget - # sys.stdout = self.downloader.log_widget - - outputs = sliding_window_inference( - inputs, - roi_size=[window_size, window_size, window_size], - sw_batch_size=1, # TODO add param - predictor=model_output, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) - - # sys.stdout = old_stdout - # sys.stderr = old_stderr - - out = outputs.detach().cpu() - - if aniso_transform is not None: - out = aniso_transform(out) - - if post_process: - out = np.array(out).astype(np.float32) - out = np.squeeze(out) - return out + if self.config.sliding_window_config.is_enabled(): + window_size = self.config.sliding_window_config.window_size + window_size = [window_size, window_size, window_size] + window_overlap = self.config.sliding_window_config.window_overlap else: - return out + window_size = None + window_overlap = 0 + try: + # logger.debug(f"model : {model}") + logger.debug(f"inputs shape : {inputs.shape}") + logger.debug(f"inputs type : {inputs.dtype}") + try: + # outputs = model(inputs) + + def model_output_wrapper(inputs): + result = model(inputs) + return post_process_transforms(result) + + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) + except Exception as e: + logger.error(e, exc_info=True) + logger.debug("failed to run sliding window inference") + self.raise_error(e, "Error during sliding window inference") + logger.debug(f"Inference output shape: {outputs.shape}") + self.log("Post-processing...") + out = outputs.detach().cpu().numpy() + if aniso_transform is not None: + out = aniso_transform(out) + if post_process: + out = np.array(out).astype(np.float32) + out = np.squeeze(out) + return out + else: + return out + except Exception as e: + logger.error(e, exc_info=True) + self.raise_error(e, "Error during sliding window inference") + # sys.stdout = old_stdout + # sys.stderr = old_stderr - def create_result_dict( # FIXME replace with result class + def create_inference_result( self, semantic_labels, instance_labels, @@ -573,7 +594,10 @@ def save_image( + f"_{time}_" + self.config.filetype ) - imwrite(file_path, image) + try: + imwrite(file_path, image) + except ValueError as e: + self.raise_error(e, "Error during image saving") filename = Path(file_path).stem if from_layer: @@ -638,7 +662,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.log(f"Inference completed on image n°{i+1}") - return self.create_result_dict( + return self.create_inference_result( out, instance_labels, from_layer=False, @@ -649,9 +673,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): def stats_csv(self, instance_labels): if self.config.compute_stats: - stats = volume_stats( - instance_labels - ) # TODO test with area mesh function + stats = volume_stats(instance_labels) return stats # except ValueError as e: @@ -677,13 +699,14 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels, stats = self.get_instance_result(out, from_layer=True) - return self.create_result_dict( + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels, from_layer=True, stats=stats, ) + # @thread_worker(connect={"errored": self.raise_error}) def inference(self): """ Requires: @@ -726,35 +749,68 @@ def inference(self): try: dims = self.config.model_info.model_input_size - # self.log(f"MODEL DIMS : {dims}") + self.log(f"MODEL DIMS : {dims}") model_name = self.config.model_info.name model_class = self.config.model_info.get_model() - self.log(model_name) + self.log(f"Model name : {model_name}") weights_config = self.config.weights_config post_process_config = self.config.post_process_config - if model_name == "SegResNet": - model = model_class.get_net( - input_image_size=[ - dims, - dims, - dims, - ], # TODO FIX ! find a better way & remove model-specific code + # try: + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[128, 128, 128], + ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, ) - elif model_name == "SwinUNetR": - model = model_class.get_net( - img_size=[dims, dims, dims], - use_checkpoint=False, + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - else: - model = model_class.get_net() - model = model.to(self.config.device) + model.load_state_dict( + torch.load( + weights, + map_location=self.config.device, + ) + ) + self.log("Done") + # except Exception as e: + # self.raise_error(e, "Issue loading weights") + # except Exception as e: + # self.raise_error(e, "Issue instantiating model") + + # if model_name == "SegResNet": + # model = model_class( + # input_image_size=[ + # dims, + # dims, + # dims, + # ], + # ) + # elif model_name == "SwinUNetR": + # model = model_class( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = model_class.get_net() self.log_parameters() - model.to(self.config.device) - # load_transforms = Compose( # [ # LoadImaged(keys=["image"]), @@ -775,25 +831,6 @@ def inference(self): AsDiscrete(threshold=t), EnsureType() ) - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path - else: - self.downloader.download_weights( - model_name, - model_class.get_weights_file(), - ) - weights = str( - WEIGHTS_DIR / Path(model_class.get_weights_file()) - ) - model.load_state_dict( - torch.load( - weights, - map_location=self.config.device, - ) - ) - self.log("Done") - is_folder = self.config.images_filepaths is not None is_layer = self.config.layer is not None @@ -818,6 +855,9 @@ def inference(self): else: raise ValueError("No data has been provided. Aborting.") + if model is None: + raise ValueError("Model is None") + model.eval() with torch.no_grad(): ################################ @@ -833,9 +873,10 @@ def inference(self): input_image, model, post_process_transforms ) model.to("cpu") - + # self.quit() except Exception as e: - self.log(f"Error during inference : {e}") + logger.error(e, exc_info=True) + self.raise_error(e, "Inference failed") self.quit() finally: self.quit() @@ -845,10 +886,10 @@ def inference(self): class TrainingReport: show_plot: bool = True epoch: int = 0 - loss_values: List = None - validation_metric: List = None + loss_values: t.Dict = None # TODO(cyril) : change to dict and unpack different losses for e.g. WNet with several losses + validation_metric: t.List = None weights: np.array = None - images: List[np.array] = None + images: t.List[np.array] = None class TrainingWorker(GeneratorWorker): @@ -900,6 +941,7 @@ def __init__( self._signals = LogSignal() self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self._weight_error = False ############################################# @@ -925,6 +967,14 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Sends an error to main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + self.quit() + def log_parameters(self): self.log("-" * 20) self.log("Parameters summary :\n") @@ -1054,29 +1104,14 @@ def train(self): do_sampling = self.config.sampling - if model_name == "SegResNet": - if do_sampling: - size = self.config.sample_size - else: - size = check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - input_image_size=utils.get_padding_dim(size), - # out_channels=1, - # dropout_prob=0.3, - ) - elif model_name == "SwinUNetR": - if do_sampling: - size = self.sample_size - else: - size = check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - img_size=utils.get_padding_dim(size), - use_checkpoint=True, - ) + if do_sampling: + size = self.config.sample_size else: - model = model_class.get_net() # get an instance of the model + size = check + + model = model_class( # FIXME check if correct + input_img_size=utils.get_padding_dim(size), use_checkpoint=True + ) model = model.to(self.config.device) epoch_loss_values = [] @@ -1216,7 +1251,11 @@ def train(self): else: load_whole_images = Compose( [ - LoadImaged(keys=["image", "label"]), + LoadImaged( + keys=["image", "label"], + # image_only=True, + # reader=WSIReader(backend="tifffile") + ), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="PLI"), SpatialPadd( @@ -1263,9 +1302,9 @@ def train(self): if weights_config.custom: if weights_config.use_pretrained: - weights_file = model_class.get_weights_file() + weights_file = model_class.weights_file self.downloader.download_weights(model_name, weights_file) - weights = WEIGHTS_DIR / Path(weights_file) + weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) weights_config.path = weights else: weights = str(Path(weights_config.path)) @@ -1279,6 +1318,7 @@ def train(self): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") + logger.error(e, exc_info=True) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" @@ -1326,7 +1366,7 @@ def train(self): batch_data["label"].to(device), ) optimizer.zero_grad() - outputs = model_class.get_output(model, inputs) + outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") loss = self.config.loss_function(outputs, labels) loss.backward() @@ -1357,10 +1397,24 @@ def train(self): val_data["image"].to(device), val_data["label"].to(device), ) - - val_outputs = model_class.get_validation( - model, val_inputs + self.log("Performing validation...") + try: + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) + except Exception as e: + self.raise_error(e, "Error during validation") + logger.debug( + f"val_outputs shape : {val_outputs.shape}" ) + # val_outputs = model(val_inputs) pred = decollate_batch(val_outputs) @@ -1407,7 +1461,7 @@ def train(self): weights=model.state_dict(), images=checkpoint_output, ) - + self.log("Validation completed") yield train_report weights_filename = ( @@ -1440,7 +1494,7 @@ def train(self): model.to("cpu") except Exception as e: - self.log(f"Error in training : {e}") + self.raise_error(e, "Error in training") self.quit() finally: self.quit() diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py index 8856e18d..8b6e6e65 100644 --- a/napari_cellseg3d/code_models/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,21 +1,33 @@ from monai.networks.nets import SegResNetVAE -def get_net(input_image_size, out_channels=1, dropout_prob=0.3): - return SegResNetVAE( - input_image_size, out_channels=out_channels, dropout_prob=dropout_prob - ) - - -def get_weights_file(): - return "SegResNet.pth" - - -def get_output(model, input): - out = model(input)[0] - return out - - -def get_validation(model, val_inputs): - val_outputs = model(val_inputs) - return val_outputs[0] +class SegResNet_(SegResNetVAE): + use_default_training = True + weights_file = "SegResNet.pth" + + def __init__( + self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs + ): + super().__init__( + input_img_size, + out_channels=out_channels, + dropout_prob=dropout_prob, + ) + + def forward(self, x): + res = SegResNetVAE.forward(self, x) + # logger.debug(f"SegResNetVAE.forward: {res[0].shape}") + return res[0] + + def get_model_test(self, size): + return SegResNetVAE( + size, in_channels=1, out_channels=1, dropout_prob=0.3 + ) + + # def get_output(model, input): + # out = model(input)[0] + # return out + + # def get_validation(model, val_inputs): + # val_outputs = model(val_inputs) + # return val_outputs[0] diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 532aeb89..fe4d380c 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,25 +1,23 @@ -import torch from monai.networks.nets import SwinUNETR -def get_weights_file(): - return "Swin64_best_metric.pth" +class SwinUNETR_(SwinUNETR): + use_default_training = True + weights_file = "Swin64_best_metric.pth" + def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs + ) -def get_net(img_size, use_checkpoint=True): - return SwinUNETR( - img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - ) + # def get_output(self, input): + # out = self(input) + # return torch.sigmoid(out) - -def get_output(model, input): - out = model(input) - return torch.sigmoid(out) - - -def get_validation(model, val_inputs): - return model(val_inputs) + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 09de2a26..8a108e37 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,28 +2,8 @@ from torch import nn -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" - # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them - - -def get_net(): - return TRAILMAP(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - return model(val_inputs) - - class TRAILMAP(nn.Module): - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input self.conv1 = self.encoderBlock(32, 64, 3) # l1 @@ -112,3 +92,20 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) return out + + +class TRAILMAP_(TRAILMAP): + use_default_training = True + weights_file = "TRAILMAP_PyTorch.pth" # model additionally trained on Mathis/Wyss mesoSPIM data + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + # def get_output(model, input): + # out = model(input) + # + # return out + + # def get_validation(model, val_inputs): + # return model(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 0fc68d34..e3ca00a6 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,20 +1,21 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D -def get_weights_file(): - # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) - return "TRAILMAP_MS_best_metric_epoch_26.pth" - - -def get_net(): - return UNet3D(1, 1) +class TRAILMAP_MS_(UNet3D): + use_default_training = True + weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) -def get_output(model, input): - out = model(input) - - return out + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + # def get_output(self, input): + # out = self(input) -def get_validation(model, val_inputs): - return model(val_inputs) + # return out + # + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 0c854832..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -1,29 +1,33 @@ -from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -def get_net(): - return VNet() - - -def get_weights_file(): - return "VNet_40e.pth" - - -def get_output(model, input): - out = model(input) - return out - - -def get_validation(model, val_inputs): - roi_size = (64, 64, 64) - sw_batch_size = 1 - val_outputs = sliding_window_inference( - val_inputs, - roi_size, - sw_batch_size, - model, - mode="gaussian", - overlap=0.7, - ) - return val_outputs +class VNet_(VNet): + use_default_training = True + weights_file = "VNet_40e.pth" + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError: + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) + + # def get_output(self, input): + # out = self(input) + # return out + + # def get_validation(self, val_inputs): # FIXME standardize + # roi_size = (64, 64, 64) + # sw_batch_size = 1 + # val_outputs = sliding_window_inference( + # val_inputs, + # roi_size, + # sw_batch_size, + # self, + # # mode="gaussian", + # # overlap=0.7, + # ) + # return val_outputs diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 5871c4a7..1ccac3da 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -2,26 +2,22 @@ from torch import nn -def get_weights_file(): - return "test.pth" - - class TestModel(nn.Module): - def __init__(self): + use_default_training = True + weights_file = "test.pth" + + def __init__(self, **kwargs): super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - def get_net(self): - return self - - def get_output(self, _, input): - return input + # def get_output(self, _, input): + # return input - def get_validation(self, val_inputs): - return val_inputs + # def get_validation(self, val_inputs): + # return val_inputs # if __name__ == "__main__": @@ -29,8 +25,8 @@ def get_validation(self, val_inputs): # model = TestModel() # model.train() # model.zero_grad() -# from napari_cellseg3d.config import WEIGHTS_DIR +# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR # torch.save( # model.state_dict(), -# WEIGHTS_DIR + f"/{get_weights_file()}" +# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" # ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index fb6fb71c..33b18eda 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -160,6 +160,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) + self.window_size_choice.setCurrentIndex(3) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -602,10 +603,13 @@ def start(self): self.worker.set_download_log(self.log) self.worker.started.connect(self.on_start) + self.worker.log_signal.connect(self.log.print_and_log) self.worker.warn_signal.connect(self.log.warn) + self.worker.error_signal.connect(self.log.error) + self.worker.yielded.connect(partial(self.on_yield)) # - self.worker.errored.connect(partial(self.on_yield)) + self.worker.errored.connect(partial(self.on_error)) self.worker.finished.connect(self.on_finish) if self.get_device(show=False) == "cuda": @@ -642,15 +646,18 @@ def on_start(self): self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") - def on_error(self): - """Catches errors and tries to clean up. TODO : upgrade""" + def on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" * 20) self.log.print_and_log("Worker errored...") - self.log.print_and_log("Trying to clean up...") + self.log.error(error) + # self.log.print_and_log("Trying to clean up...") + self.worker.quit() self.btn_start.setText("Start") self.btn_close.setVisible(True) - self.worker = None self.worker_config = None + self.worker = None self.empty_cuda_cache() def on_finish(self): @@ -673,85 +680,91 @@ def on_yield(self, result: InferenceResult): data (dict): dict yielded by :py:func:`~inference()`, contains : "image_id" : index of the returned image, "original" : original volume used for inference, "result" : inference result widget (QWidget): widget for accessing attributes """ + + if isinstance(result, Exception): + self.on_error(result) + # raise result # viewer, progress, show_res, show_res_number, zoon, show_original # check that viewer checkbox is on and that max number of displays has not been reached. # widget.log.print_and_log(result) + try: + image_id = result.image_id + model_name = result.model_name + if self.worker_config.images_filepaths is not None: + total = len(self.worker_config.images_filepaths) + else: + total = 1 - image_id = result.image_id - model_name = result.model_name - if self.worker_config.images_filepaths is not None: - total = len(self.worker_config.images_filepaths) - else: - total = 1 + viewer = self._viewer - viewer = self._viewer + pbar_value = image_id // total + if pbar_value == 0: + pbar_value = 1 - pbar_value = image_id // total - if pbar_value == 0: - pbar_value = 1 + self.progress.setValue(100 * pbar_value) - self.progress.setValue(100 * pbar_value) + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): + zoom = self.worker_config.post_process_config.zoom.zoom_values - if ( - self.config.show_results - and image_id <= self.config.show_results_count - ): - zoom = self.worker_config.post_process_config.zoom.zoom_values + viewer.dims.ndisplay = 3 + viewer.scale_bar.visible = True - viewer.dims.ndisplay = 3 - viewer.scale_bar.visible = True + if self.config.show_original and result.original is not None: + viewer.add_image( + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, + ) + + out_colormap = "twilight" + if self.worker_config.post_process_config.thresholding.enabled: + out_colormap = "turbo" - if self.config.show_original and result.original is not None: viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, ) - out_colormap = "twilight" - if self.worker_config.post_process_config.thresholding.enabled: - out_colormap = "turbo" - - viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, - ) - - if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + if result.instance_labels is not None: + labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + stats = result.stats - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if self.worker_config.compute_stats and stats is not None: + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + self.log.print_and_log( + f"Number of instances : {stats.number_objects}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) + except Exception as e: + self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index de54b345..97915efc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -982,7 +982,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1150,7 +1150,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e) + logger.error(e, exc_info=True) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index a803dfd7..a1a167a4 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -404,7 +404,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index ab3dba39..7b05d65c 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -11,12 +11,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet -from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.code_models.models import ( - model_TRAILMAP_MS as TRAILMAP_MS, -) -from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ +from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ +from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ +from napari_cellseg3d.code_models.models.model_VNet import VNet_ + from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -25,16 +24,15 @@ # TODO(cyril) add JSON load/save MODEL_LIST = { - "SegResNet": SegResNet, - "VNet": VNet, + "SegResNet": SegResNet_, + "VNet": VNet_, # "TRAILMAP": TRAILMAP, - "TRAILMAP_MS": TRAILMAP_MS, - "SwinUNetR": SwinUNetR, + "TRAILMAP_MS": TRAILMAP_MS_, + "SwinUNetR": SwinUNETR_, # "test" : DO NOT USE, reserved for testing } - -WEIGHTS_DIR = str( +PRETRAINED_WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") ) @@ -70,8 +68,11 @@ class ReviewSession: @dataclass class ModelInfo: - """Dataclass recording model info : - - name (str): name of the model""" + """Dataclass recording model info + Args: + name (str): name of the model + model_input_size (Optional[List[int]]): input size of the model + """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None @@ -95,7 +96,7 @@ def get_model_name_list(): @dataclass class WeightsInfo: - path: str = WEIGHTS_DIR + path: str = PRETRAINED_WEIGHTS_DIR custom: bool = False use_pretrained: Optional[bool] = False @@ -124,6 +125,14 @@ class InstanceSegConfig: @dataclass class PostProcessConfig: + """Class to record params for post processing + + Args: + zoom (Zoom): zoom config + thresholding (Thresholding): thresholding config + instance (InstanceSegConfig): instance segmentation config + """ + zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() instance: InstanceSegConfig = InstanceSegConfig() @@ -144,7 +153,15 @@ def is_enabled(self): @dataclass class InfererConfig: - """Class to record params for Inferer plugin""" + """Class to record params for Inferer plugin + + Args: + model_info (ModelInfo): model info + show_results (bool): show results in napari + show_results_count (int): number of results to show + show_original (bool): show original image in napari + anisotropy_resolution (List[int]): anisotropy resolution + """ model_info: ModelInfo = None show_results: bool = False @@ -155,7 +172,21 @@ class InfererConfig: @dataclass class InferenceWorkerConfig: - """Class to record configuration for Inference job""" + """Class to record configuration for Inference job + + Args: + device (str): device to use for inference + model_info (ModelInfo): model info + weights_config (WeightsInfo): weights info + results_path (str): path to save results + filetype (str): filetype to save results + keep_on_cpu (bool): keep results on cpu + compute_stats (bool): compute stats + post_process_config (PostProcessConfig): post processing config + sliding_window_config (SlidingWindowConfig): sliding window config + images_filepaths (str): path to images to infer + layer (napari.layers.Layer): napari layer to infer on + """ device: str = "cpu" model_info: ModelInfo = ModelInfo() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 251fbb06..6417f073 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -297,6 +297,22 @@ def warn(self, warning): finally: self.lock.release() + def error(self, error, msg=None): + """Show exception and message from another thread""" + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Excepetion caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() + ############## # UI elements @@ -1201,7 +1217,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") filenames = QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + widget, "Open directory", default_path + "/.." ) return filenames diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 4b04d536..bc6203be 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,10 +2,8 @@ import warnings from datetime import datetime from pathlib import Path - import numpy as np - -# from dask import delayed +from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -38,6 +36,18 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] +# class TiffFileReader(ImageReader): +# def __init__(self): +# super().__init__() +# +# def verify_suffix(self, filename): +# if filename == "tif": +# return True +# def read(self, data, **kwargs): +# return tfl_imread(data) +# +# def get_data(self, data): +# return data, {} def normalize_x(image): """Normalizes the values of an image array to be between [-1;1] rather than [0;255] @@ -122,8 +132,6 @@ def dice_coeff(y_true, y_pred): def resize(image, zoom_factors): - from monai.transforms import Zoom - isotropic_image = Zoom( zoom_factors, keep_size=False, diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..59ae05c1 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,10 +44,20 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -57,14 +67,15 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", + "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = resize(prediction, zoom)\n", + "prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -85,7 +96,7 @@ { "data": { "text/plain": [ - "0.5817600487210719" + "0.7538125057831502" ] }, "execution_count": 4, @@ -96,9 +107,15 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", + "semantic_gt = to_semantic(gt_labels_resized.copy())\n", + "semantic_pred = to_semantic(prediction_resized.copy())\n", + "\n", + "viewer.add_image(semantic_gt, colormap='bop blue')\n", + "viewer.add_image(semantic_pred, colormap='red')\n", + "\n", "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", + " semantic_gt,\n", + " prediction_resized\n", ")" ] }, @@ -171,7 +188,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -198,24 +215,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-24 14:23:13,590 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", + "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", + "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" ] }, { @@ -228,15 +245,15 @@ { "data": { "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" + "(63,\n", + " 45,\n", + " 16,\n", + " 16,\n", + " 0.819027731148306,\n", + " 0.8401649108992161,\n", + " 0.83609908334452,\n", + " 0.8066092803671974,\n", + " 0.98)" ] }, "execution_count": 9, @@ -262,24 +279,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" + "2023-03-24 14:23:13,732 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", + "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", + "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -292,15 +309,15 @@ { "data": { "text/plain": [ - "(68,\n", - " 43,\n", + "(77,\n", + " 34,\n", " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" + " 9,\n", + " 0.728461197681457,\n", + " 0.8885669859686413,\n", + " 0.8950588507577087,\n", + " 0.7472814623489069,\n", + " 0.878614359974009)" ] }, "execution_count": 10, @@ -338,7 +355,7 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", @@ -414,24 +431,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-24 14:23:14,241 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", + "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", + "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" ] }, { @@ -444,15 +461,15 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(102,\n", + " 8,\n", + " 14,\n", + " 16,\n", + " 0.708505702558253,\n", + " 0.8832633585884945,\n", + " 0.9759871495093808,\n", + " 0.6670483272595948,\n", + " 0.8653680990771155)" ] }, "execution_count": 15, diff --git a/requirements.txt b/requirements.txt index 3189e9c4..3ca0e56d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ black coverage +imageio-ffmpeg>=0.4.5 isort itk pytest @@ -15,13 +16,12 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 +ruff tifffile>=2022.2.9 -imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,einops]>=1.0.1 +monai[nibabel,einops,tifffile]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 2420dd1c..f3294b60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=1.0.1 + monai[nibabel,einops,tifffile]>=1.0.1 itk tqdm nibabel From 556ff672833dc016ede0298a5394c0e9279b505e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 29 Mar 2023 09:55:58 +0200 Subject: [PATCH 076/577] Added LR scheduler in training - Added ReduceLROnPlateau with params in training - Updated training guide - Minor UI attribute refactor - black --- docs/res/code/plugin_model_training.rst | 1 - docs/res/guides/training_module_guide.rst | 2 + napari_cellseg3d/_tests/fixtures.py | 3 + .../_tests/test_plugin_inference.py | 2 +- .../code_models/model_framework.py | 2 +- .../code_models/model_instance_seg.py | 8 +-- napari_cellseg3d/code_models/model_workers.py | 11 ++++ napari_cellseg3d/code_plugins/plugin_base.py | 2 +- .../code_plugins/plugin_convert.py | 4 +- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- .../code_plugins/plugin_model_inference.py | 4 +- .../code_plugins/plugin_model_training.py | 62 ++++++++++++------- .../code_plugins/plugin_utilities.py | 2 +- napari_cellseg3d/config.py | 2 + napari_cellseg3d/interface.py | 43 +++++++------ 15 files changed, 93 insertions(+), 57 deletions(-) diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index 870dfd14..dc1271fc 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -18,6 +18,5 @@ Methods Attributes ********************* - .. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index fb8992d2..05ce69be 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -74,6 +74,8 @@ The training module is comprised of several tabs. * The **batch size** (larger means quicker training and possibly better performance but increased memory usage) * The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.) * The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.) +* The **schedular patience**, which is the amount of epoch at a plateau that is waited for until the learning rate is reduced +* The **scheduler factor**, which is the factor by which to reduce the learning rate once a plateau is reached * Whether to use deterministic training, and the seed to use. .. note:: diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index b40a77d3..bd6b0ac7 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -14,3 +14,6 @@ def print_and_log(self, text, printing=None): def warn(self, warning): warnings.warn(warning) + + def error(self, e): + raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..66c50fba 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -38,4 +38,4 @@ def test_inference(make_napari_viewer, qtbot): # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: # blocker.connect(widget.worker.errored) - # assert len(viewer.layers) == 2 + #### assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index b3121cf4..1e6b934a 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -80,7 +80,7 @@ def __init__( # ) self.model_choice = ui.DropdownMenu( - sorted(self.available_models.keys()), label="Model name" + sorted(self.available_models.keys()), text_label="Model name" ) self.weights_filewidget = ui.FilePathWidget( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cc362eac..4f702a4b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -74,7 +74,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label="", parent=None), + ui.DoubleIncrementCounter(text_label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -393,13 +393,13 @@ def __init__(self, widget_parent=None): widget_parent=widget_parent, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" self.sliders[0].setValue(50) - self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].label.setText("Seed probability threshold") self.sliders[1].tooltips = "Probability threshold for seeding" self.sliders[1].setValue(90) @@ -439,7 +439,7 @@ def __init__(self, widget_parent=None): widget_parent=widget_parent, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 25a9da19..71abaaed 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -69,6 +69,7 @@ PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) +VERBOSE_SCHEDULER = True logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") @@ -1292,6 +1293,13 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + mode="min", + factor=self.config.scheduler_factor, + patience=self.config.scheduler_patience, + verbose=VERBOSE_SCHEDULER, + ) dice_metric = DiceMetric(include_background=True, reduction="mean") best_metric = -1 @@ -1384,6 +1392,9 @@ def train(self): epoch_loss_values.append(epoch_loss) self.log(f"Epoch: {epoch + 1}, Average loss: {epoch_loss:.4f}") + self.log("Updating scheduler...") + scheduler.step(epoch_loss) + checkpoint_output = [] if ( diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 7c5fbaa5..02c9fbff 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -100,7 +100,7 @@ def __init__( ) self.filetype_choice = ui.DropdownMenu( - [".tif", ".tiff"], label="File format" + [".tif", ".tiff"], text_label="File format" ) ######## qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ed1a43df..547e0233 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -209,7 +209,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): lower=1, upper=100000, default=10, - label="Remove all smaller than (pxs):", + text_label="Remove all smaller than (pxs):", ) self.results_path = Path.home() / Path("cellseg3d/small_removed") @@ -469,7 +469,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): upper=100000.0, step=0.5, default=10.0, - label="Remove all smaller than (value):", + text_label="Remove all smaller than (value):", ) self.results_path = Path.home() / Path("cellseg3d/threshold") diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 07885236..789be9e5 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -536,7 +536,7 @@ def set_slice( # container_widget.extend(sliders) ui.add_widgets( container_widget.layout, - [ui.combine_blocks(s, s.text_label) for s in sliders], + [ui.combine_blocks(s, s.label) for s in sliders], ) # vw.window.add_dock_widget([spinbox, container_widget], area="right") wdgts = vw.window.add_dock_widget( diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 33b18eda..060964df 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -106,7 +106,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ###################### # TODO : better way to handle SegResNet size reqs ? self.model_input_size = ui.IntIncrementCounter( - lower=1, upper=1024, default=128, label="\nModel input size" + lower=1, upper=1024, default=128, text_label="\nModel input size" ) self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size @@ -158,7 +158,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): # ) self.window_size_choice = ui.DropdownMenu( - sizes_window, label="Window size" + sizes_window, text_label="Window size" ) self.window_size_choice.setCurrentIndex(3) # set to 64 by default diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 97915efc..7bd1b0bf 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -43,6 +43,8 @@ class Trainer(ModelFramework, metaclass=ui.QWidgetSingleton): Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" + default_config = config.TrainingWorkerConfig() + def __init__( self, viewer: "napari.viewer.Viewer", @@ -165,14 +167,13 @@ def __init__( ################################ # interface - default = config.TrainingWorkerConfig() self.zip_choice = ui.CheckBox("Compress results") self.validation_percent_choice = ui.Slider( lower=10, upper=90, - default=default.validation_percent * 100, + default=self.default_config.validation_percent * 100, step=5, parent=self, ) @@ -180,12 +181,12 @@ def __init__( self.epoch_choice = ui.IntIncrementCounter( lower=2, upper=200, - default=default.max_epochs, - label="Number of epochs : ", + default=self.default_config.max_epochs, + text_label="Number of epochs : ", ) self.loss_choice = ui.DropdownMenu( - sorted(self.loss_dict.keys()), label="Loss function" + sorted(self.loss_dict.keys()), text_label="Loss function" ) self.lbl_loss_choice = self.loss_choice.label self.loss_choice.setCurrentIndex(0) @@ -193,7 +194,7 @@ def __init__( self.sample_choice_slider = ui.Slider( lower=2, upper=50, - default=default.num_samples, + default=self.default_config.num_samples, text_label="Number of patches per image : ", ) @@ -202,13 +203,13 @@ def __init__( self.batch_choice = ui.Slider( lower=1, upper=10, - default=default.batch_size, + default=self.default_config.batch_size, text_label="Batch size : ", ) self.val_interval_choice = ui.IntIncrementCounter( - default=default.validation_interval, - label="Validation interval : ", + default=self.default_config.validation_interval, + text_label="Validation interval : ", ) self.epoch_choice.valueChanged.connect(self._update_validation_choice) @@ -225,12 +226,24 @@ def __init__( ] self.learning_rate_choice = ui.DropdownMenu( - learning_rate_vals, label="Learning rate" + learning_rate_vals, text_label="Learning rate" ) self.lbl_learning_rate_choice = self.learning_rate_choice.label self.learning_rate_choice.setCurrentIndex(1) + self.scheduler_patience_choice = ui.IntIncrementCounter( + 1, + 99, + default=self.default_config.scheduler_patience, + text_label="Scheduler patience", + ) + self.scheduler_factor_choice = ui.Slider( + divide_factor=100, + default=self.default_config.scheduler_factor * 100, + text_label="Scheduler factor :", + ) + self.augment_choice = ui.CheckBox("Augment data") self.close_buttons = [ @@ -265,7 +278,8 @@ def __init__( "Deterministic training", func=self._toggle_deterministic_param ) self.box_seed = ui.IntIncrementCounter( - upper=10000000, default=default.deterministic_config.seed + upper=10000000, + default=self.default_config.deterministic_config.seed, ) self.lbl_seed = ui.make_label("Seed", self) self.container_seed = ui.combine_blocks( @@ -306,6 +320,12 @@ def set_tooltips(): self.learning_rate_choice.setToolTip( "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights" ) + self.scheduler_factor_choice.setToolTip( + "The factor by which to reduce the learning rate once the loss reaches a plateau" + ) + self.scheduler_patience_choice.setToolTip( + "The amount of epochs to wait for before reducing the learning rate" + ) self.augment_choice.setToolTip( "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images" " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image" @@ -629,26 +649,20 @@ def _build(self): "Training parameters", r=1, b=5, t=11 ) - spacing = 20 - ui.add_widgets( train_param_group_l, [ self.batch_choice.container, # batch size - ui.combine_blocks( - self.learning_rate_choice, - self.lbl_learning_rate_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # learning rate + self.lbl_learning_rate_choice, + self.learning_rate_choice, self.epoch_choice.label, # epochs self.epoch_choice, self.val_interval_choice.label, self.val_interval_choice, # validation interval + self.scheduler_patience_choice.label, + self.scheduler_patience_choice, + self.scheduler_factor_choice.label, + self.scheduler_factor_choice.container, ], None, ) @@ -830,6 +844,8 @@ def start(self): max_epochs=self.epoch_choice.value(), loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.value(), validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index c962717e..e141bfe5 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -41,7 +41,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # self.small = RemoveSmallUtils(self._viewer) self.utils_choice = ui.DropdownMenu( - UTILITIES_WIDGETS.keys(), label="Utilities" + UTILITIES_WIDGETS.keys(), text_label="Utilities" ) self._build() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 7b05d65c..e22f1b62 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -233,6 +233,8 @@ class TrainingWorkerConfig: max_epochs: int = 5 loss_function: callable = None learning_rate: np.float64 = 1e-3 + scheduler_patience: int = 10 + scheduler_factor: float = 0.5 validation_interval: int = 2 batch_size: int = 1 results_path_folder: str = str(Path.home() / Path("cellseg3d/training")) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 6417f073..484d137d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -410,21 +410,21 @@ def __init__( self, entries: Optional[list] = None, parent: Optional[QWidget] = None, - label: Optional[str] = None, + text_label: Optional[str] = None, fixed: Optional[bool] = True, ): """Args: entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None - label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + text_label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. """ super().__init__(parent) self.label = None if entries is not None: self.addItems(entries) - if label is not None: - self.label = QLabel(label) + if text_label is not None: + self.label = QLabel(text_label) if fixed: self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) @@ -475,9 +475,10 @@ def __init__( self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.text_label = None + self.label = None self.container = ContainerWidget( - # parent=self.parent + # parent=self.parent, + b=0, ) self._divide_factor = divide_factor @@ -500,7 +501,7 @@ def __init__( ) if text_label is not None: - self.text_label = make_label(text_label, parent=self) + self.label = make_label(text_label, parent=self) if default < lower: self._warn_outside_bounds(default) @@ -519,14 +520,14 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.text_label.setVisible(visible) + self.label.setVisible(visible) def _build_container(self): - if self.text_label is not None: + if self.label is not None: add_widgets( self.container.layout, [ - self.text_label, + self.label, combine_blocks(self._value_label, self, b=0), ], ) @@ -570,8 +571,8 @@ def tooltips(self, tooltip: str): self.setToolTip(tooltip) self._value_label.setToolTip(tooltip) - if self.text_label is not None: - self.text_label.setToolTip(tooltip) + if self.label is not None: + self.label.setToolTip(tooltip) @property def slider_value(self): @@ -741,7 +742,9 @@ def __init__( self.image = None self.layer_type = layer_type - self.layer_list = DropdownMenu(parent=self, label=name, fixed=False) + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? self._viewer.layers.events.inserted.connect(partial(self._add_layer)) @@ -1046,7 +1049,7 @@ def __init__( step: Optional[float] = 1.0, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[float]): minimum value, defaults to 0 @@ -1055,7 +1058,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) @@ -1063,8 +1066,8 @@ def __init__( self.layout = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) self.valueChanged.connect(self._update_step) def _update_step(self): @@ -1124,7 +1127,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[int]): minimum value, defaults to 0 @@ -1140,8 +1143,8 @@ def __init__( self.label = None self.container = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) @property def tooltips(self): From 55f46f892a6aca6b5b0cf354f52716cd2ec2a380 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 31 Mar 2023 15:45:00 +0200 Subject: [PATCH 077/577] Update assess_instance.ipynb --- notebooks/assess_instance.ipynb | 162 ++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 61 deletions(-) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 59ae05c1..0dec4543 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,20 +44,10 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -67,15 +57,16 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "# prediction_resized = resize(prediction, zoom)\n", - "prediction_resized = prediction # for trailmap\n", + "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -96,7 +87,7 @@ { "data": { "text/plain": [ - "0.7538125057831502" + "0.8592223181276479" ] }, "execution_count": 4, @@ -188,7 +179,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -215,24 +206,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,590 - Mapping labels...\n" + "2023-03-31 15:37:19,775 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", - "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", - "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" + "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", + "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -245,15 +236,15 @@ { "data": { "text/plain": [ - "(63,\n", - " 45,\n", - " 16,\n", - " 16,\n", - " 0.819027731148306,\n", - " 0.8401649108992161,\n", - " 0.83609908334452,\n", - " 0.8066092803671974,\n", - " 0.98)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -279,24 +270,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,732 - Mapping labels...\n" + "2023-03-31 15:37:19,919 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", - "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", - "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", + "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -309,15 +300,15 @@ { "data": { "text/plain": [ - "(77,\n", - " 34,\n", + "(68,\n", + " 43,\n", " 13,\n", - " 9,\n", - " 0.728461197681457,\n", - " 0.8885669859686413,\n", - " 0.8950588507577087,\n", - " 0.7472814623489069,\n", - " 0.878614359974009)" + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -343,6 +334,40 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:37:21,076 - build program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", + "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", + "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", + "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", + "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", + "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", + "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", + "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", + "2023-03-31 15:37:32,685 - build program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", + "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" + ] + }, { "data": { "text/plain": [ @@ -431,24 +456,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,241 - Mapping labels...\n" + "2023-03-31 15:37:36,854 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", - "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", - "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" + "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", + "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", + "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -461,15 +486,15 @@ { "data": { "text/plain": [ - "(102,\n", + "(109,\n", + " 2,\n", + " 13,\n", " 8,\n", - " 14,\n", - " 16,\n", - " 0.708505702558253,\n", - " 0.8832633585884945,\n", - " 0.9759871495093808,\n", - " 0.6670483272595948,\n", - " 0.8653680990771155)" + " 0.8285521200005869,\n", + " 0.8809251900364068,\n", + " 0.9838709677419355,\n", + " 0.782258064516129,\n", + " 1.0)" ] }, "execution_count": 15, @@ -490,10 +515,25 @@ "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" + ] + } + ], "source": [ "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -512,7 +552,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, From cf96be761b482d8fc39b2161ad6e4d1434113ab5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 11:09:30 +0200 Subject: [PATCH 078/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..df67a187 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb *.prof #include test data From 3ab688a38c9d2aa13efdb368fbed14cd7dacf7e8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 14:27:21 +0200 Subject: [PATCH 079/577] Started adding WNet --- napari_cellseg3d/code_models/model_workers.py | 4 +- .../code_models/models/model_SwinUNetR.py | 29 +- .../code_models/models/model_TRAILMAP_MS.py | 15 +- .../code_models/models/model_WNet.py | 27 ++ .../pretrained/pretrained_model_urls.json | 1 + .../code_models/models/wnet/__init__.py | 0 .../code_models/models/wnet/crf.py | 112 ++++++ .../code_models/models/wnet/model.py | 189 ++++++++++ .../code_models/models/wnet/soft_Ncuts.py | 352 ++++++++++++++++++ napari_cellseg3d/config.py | 22 ++ 10 files changed, 739 insertions(+), 12 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/model_WNet.py create mode 100644 napari_cellseg3d/code_models/models/wnet/__init__.py create mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py create mode 100644 napari_cellseg3d/code_models/models/wnet/model.py create mode 100644 napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 71abaaed..7d09e288 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -761,7 +761,9 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=[128, 128, 128], + input_img_size=dims, + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) # try: model = model.to(self.config.device) diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index fe4d380c..f38409b8 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,7 @@ from monai.networks.nets import SwinUNETR +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class SwinUNETR_(SwinUNETR): @@ -6,14 +9,24 @@ class SwinUNETR_(SwinUNETR): weights_file = "Swin64_best_metric.pth" def __init__(self, input_img_size, use_checkpoint=True, **kwargs): - super().__init__( - input_img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - **kwargs - ) + try: + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs, + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index e3ca00a6..1123173a 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,4 +1,7 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class TRAILMAP_MS_(UNet3D): @@ -8,9 +11,15 @@ class TRAILMAP_MS_(UNet3D): # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) def __init__(self, in_channels=1, out_channels=1, **kwargs): - super().__init__( - in_channels=in_channels, out_channels=out_channels, **kwargs - ) + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py new file mode 100644 index 00000000..63a91b10 --- /dev/null +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -0,0 +1,27 @@ +from napari_cellseg3d.code_models.models.wnet.model import WNet + + +class WNet_(WNet): + use_default_training = False + weights_file = "wnet.pth" + + def __init__( + self, + in_channels=1, + out_channels=1, + num_classes=2, + device="cpu", + **kwargs + ): + super().__init__( + device=device, + in_channels=in_channels, + out_channels=out_channels, + num_classes=num_classes, + ) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + # dec = self.forward_decoder(enc) + return enc diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index cd0782fb..cde5e332 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -3,5 +3,6 @@ "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz", "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", + "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" } diff --git a/napari_cellseg3d/code_models/models/wnet/__init__.py b/napari_cellseg3d/code_models/models/wnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py new file mode 100644 index 00000000..ca11fba2 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -0,0 +1,112 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +import numpy as np +import pydensecrf.densecrf as dcrf +from pydensecrf.utils import ( + unary_from_softmax, + create_pairwise_gaussian, + create_pairwise_bilateral, +) + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py new file mode 100644 index 00000000..585ea0dd --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -0,0 +1,189 @@ +""" +Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. +The model performs unsupervised segmentation of 3D images. +""" + +import torch +import torch.nn as nn + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", +] + + +class WNet(nn.Module): + """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. + The model performs unsupervised segmentation of 3D images. + It first encodes the input image into a latent space using the U-Net UEncoder, then decodes it back to the original image using the U-Net UDecoder. + """ + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super(WNet, self).__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + self.decoder = UNet(device, num_classes, out_channels, encoder=False) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + dec = self.forward_decoder(enc) + return enc, dec + + def forward_encoder(self, x): + """Forward pass of the encoder part of the W-Net model.""" + enc = self.encoder(x) + return enc + + def forward_decoder(self, enc): + """Forward pass of the decoder part of the W-Net model.""" + dec = self.decoder(enc) + return dec + + +class UNet(nn.Module): + """Half of the W-Net model, based on the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels, encoder=True): + super(UNet, self).__init__() + self.device = device + self.in_b = InBlock(device, in_channels, 64) + self.conv1 = Block(device, 64, 128) + self.conv2 = Block(device, 128, 256) + self.conv3 = Block(device, 256, 512) + self.bot = Block(device, 512, 1024) + self.deconv1 = Block(device, 1024, 512) + self.deconv2 = Block(device, 512, 256) + self.deconv3 = Block(device, 256, 128) + self.out_b = OutBlock(device, 128, out_channels) + + self.sm = nn.Softmax(dim=1).to(device) + self.encoder = encoder + + def forward(self, x): + """Forward pass of the U-Net model.""" + in_b = self.in_b(x.to(self.device)) + c1 = self.conv1(nn.MaxPool3d(2)(in_b)) + c2 = self.conv2(nn.MaxPool3d(2)(c1)) + c3 = self.conv3(nn.MaxPool3d(2)(c2)) + x = self.bot(nn.MaxPool3d(2)(c3)) + x = self.deconv1( + torch.cat( + [ + c3, + nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv2( + torch.cat( + [ + c2, + nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv3( + torch.cat( + [ + c1, + nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.out_b( + torch.cat( + [ + in_b, + nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + if self.encoder: + x = self.sm(x) + return x + + +class InBlock(nn.Module): + """Input block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(InBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the input block.""" + return self.module(x.to(self.device)) + + +class Block(nn.Module): + """Basic block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(Block, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), + nn.Conv3d(in_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.Conv3d(out_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the basic block.""" + return self.module(x.to(self.device)) + + +class OutBlock(nn.Module): + """Output block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(OutBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, out_channels, 1, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the output block.""" + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py new file mode 100644 index 00000000..6a625355 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -0,0 +1,352 @@ +""" +Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. +The implementation was adapted and approximated to reduce computational and memory cost. +This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from scipy.stats import norm + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", + "Jianbo Shi", + "Jitendra Malik", + "Frank Odom", +] + + +class SoftNCutsLoss(nn.Module): + """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. + + Args: + data_shape (H, W, D): shape of the images as a tuple. + o_i (scalar): scale of the gaussian kernel of pixels brightness. + o_x (scalar): scale of the gaussian kernel of pixels spacial distance. + radius (scalar): radius of pixels for which we compute the weights + """ + + def __init__(self, data_shape, device, o_i, o_x, radius=None): + super(SoftNCutsLoss, self).__init__() + self.o_i = o_i + self.o_x = o_x + self.radius = radius + self.H = data_shape[0] + self.W = data_shape[1] + self.D = data_shape[2] + self.device = device + + if self.radius is None: + self.radius = min( + max(5, math.ceil(min(self.H, self.W, self.D) / 20)), + self.H, + self.W, + self.D, + ) + + # self.distances, self.indexes = self.get_distances() + + """ + + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration + distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) + distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) + distances_D = torch.tensor(range(self.D)).expand(self.D, self.D) # (D, D) + + # Compute in cuda if possible + if torch.cuda.is_available(): + distances_H = distances_H.cuda() + distances_W = distances_W.cuda() + distances_D = distances_D.cuda() + + distances_H = torch.abs(torch.subtract(distances_H, distances_H.T)) # (H, H) + distances_W = torch.abs(torch.subtract(distances_W, distances_W.T)) # (W, W) + distances_D = torch.abs(torch.subtract(distances_D, distances_D.T)) # (D, D) + + distances_H = distances_H.view(self.H, 1, 1, self.H, 1, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (H, 1, 1, H, 1, 1) -> (H, W, D, H, W, D) + distances_W = distances_W.view(1, self.W, 1, 1, self.W, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, W, 1, 1, W, 1) -> (H, W, D, H, W, D) + distances_D = distances_D.view(1, 1, self.D, 1, 1, self.D).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, 1, D, 1, 1, D) -> (H, W, D, H, W, D) + + mask_H = torch.le(distances_H, self.radius).bool() # (H, W, D, H, W, D) + mask_W = torch.le(distances_W, self.radius).bool() # (H, W, D, H, W, D) + mask_D = torch.le(distances_D, self.radius).bool() # (H, W, D, H, W, D) + + distances_H = (distances_H * mask_H) # (H, W, D, H, W, D) + distances_W = (distances_W * mask_W) # (H, W, D, H, W, D) + distances_D = (distances_D * mask_D) # (H, W, D, H, W, D) + + mask_H =mask_H.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_W =mask_W.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_D =mask_D.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + + distances_H = distances_H.pow(2) # (H, W, D, H, W, D) + distances_W = distances_W.pow(2) # (H, W, D, H, W, D) + distances_D = distances_D.pow(2) # (H, W, D, H, W, D) + + squared_distances = torch.add( + torch.add(distances_H, distances_W), + distances_D, + ) # (H, W, D, H, W, D) + + squared_distances = squared_distances.flatten(0, 2).flatten( + 1, 3 + ) # (H*W*D, H*W*D) + + # Mask to only keep the weights for the pixels in the radius + self.mask = torch.le(squared_distances, self.radius**2).bool() # (H*W*D, H*W*D) + + # Add all masks to get the final mask + self.mask = self.mask.logical_and(mask_H).logical_and(mask_W).logical_and(mask_D) # (H*W*D, H*W*D) + + W_X = torch.exp( + torch.neg(torch.div(squared_distances, self.o_x)) + ) # (H*W*D, H*W*D) + + self.W_X = torch.mul(W_X, self.mask) # (H*W*D, H*W*D) + """ + + def forward(self, labels, inputs): + """Forward pass of the Soft N-Cuts loss. + + Args: + labels (torch.Tensor): Tensor of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + The Soft N-Cuts loss of shape (N,). + """ + inputs.shape[0] + inputs.shape[1] + K = labels.shape[1] + + labels.to(self.device) + inputs.to(self.device) + + loss = 0 + + kernel = self.gaussian_kernel(self.radius, self.o_x).to(self.device) + + for k in range(K): + # Compute the average pixel value for this class, and the difference from each pixel + class_probs = labels[:, k].unsqueeze(1) + class_mean = torch.mean( + inputs * class_probs, dim=(2, 3, 4), keepdim=True + ) / torch.add( + torch.mean(class_probs, dim=(2, 3, 4), keepdim=True), 1e-5 + ) + diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1) + + # Weight the loss by the difference from the class average. + weights = torch.exp(diff.pow(2).mul(-1 / self.o_i**2)) + + numerator = torch.sum( + class_probs + * F.conv3d(class_probs * weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + denominator = torch.sum( + class_probs * F.conv3d(weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + loss += nn.L1Loss()( + numerator / torch.add(denominator, 1e-6), + torch.zeros_like(numerator), + ) + + return K - loss + + """ + for k in range(K): + Ak = labels[:, k, :, :, :] # (N, H, W, D) + flatted_Ak = Ak.view(N, -1) # (N, H*W*D) + + # Compute the numerator of the Soft N-Cuts loss for k + flatted_Ak_unsqueeze = flatted_Ak.unsqueeze(1) # (N, 1, H*W*D) + transposed_Ak = torch.transpose(flatted_Ak_unsqueeze, 1, 2) # (N, H*W*D, 1) + probs = torch.bmm(transposed_Ak, flatted_Ak_unsqueeze) # (N, H*W*D, H*W*D) + probs_unsqueeze_expanded = probs.unsqueeze(1) # (N, 1, H*W*D, H*W*D) + numerator_elements = torch.mul( + probs_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + numerator = torch.sum(numerator_elements, dim=(2, 3)) # (N, C) + + # Compute the denominator of the Soft N-Cuts loss for k + expanded_flatted_Ak = flatted_Ak.expand( + -1, self.H * self.W * self.D + ) # (N, H*W*D, H*W*D) + e_f_Ak_unsqueeze_expanded = expanded_flatted_Ak.unsqueeze( + 1 + ) # (N, 1, H*W*D, H*W*D) + denominator_elements = torch.mul( + e_f_Ak_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + denominator = torch.sum(denominator_elements, dim=(2, 3)) # (N, C) + + # Compute the Soft N-Cuts loss for k + division = torch.div(numerator, torch.add(denominator, 1e-8)) # (N, C) + loss = torch.sum(division, dim=1) # (N,) + losses.append(loss) + + loss = torch.sum(torch.stack(losses, dim=0), dim=0) # (N,) + + return torch.add(torch.neg(loss), K) + """ + + def gaussian_kernel(self, radius, sigma): + """Computes the Gaussian kernel. + + Args: + radius (int): The radius of the kernel. + sigma (float): The standard deviation of the Gaussian distribution. + + Returns: + The Gaussian kernel of shape (1, 1, 2*radius+1, 2*radius+1, 2*radius+1). + """ + x_2 = np.linspace(-radius, radius, 2 * radius + 1) ** 2 + dist = ( + np.sqrt( + x_2.reshape(-1, 1, 1) + + x_2.reshape(1, -1, 1) + + x_2.reshape(1, 1, -1) + ) + / sigma + ) + kernel = norm.pdf(dist) / norm.pdf(0) + kernel = torch.from_numpy(kernel.astype(np.float32)) + kernel = kernel.view( + (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]) + ) + + return kernel + + def get_distances(self): + """Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration. + + Returns: + distances (dict): for each pixel index, we get the distances to the pixels in a radius around it. + """ + + distances = dict() + indexes = np.array( + [ + (i, j, k) + for i in range(self.H) + for j in range(self.W) + for k in range(self.D) + ] + ) + + for i in indexes: + iTuple = (i[0], i[1], i[2]) + distances[iTuple] = dict() + + sliceD = indexes[ + i[0] * self.H + + i[1] * self.W + + max(0, i[2] - self.radius) : i[0] * self.H + + i[1] * self.W + + min(self.D, i[2] + self.radius) + ] + sliceW = indexes[ + i[0] * self.H + + max(0, i[1] - self.radius) * self.W + + i[2] : i[0] * self.H + + min(self.W, i[1] + self.radius) * self.W + + i[2] : self.D + ] + sliceH = indexes[ + max(0, i[0] - self.radius) * self.H + + i[1] * self.W + + i[2] : min(self.H, i[0] + self.radius) * self.H + + i[1] * self.W + + i[2] : self.D * self.W + ] + + for j in np.concatenate((sliceD, sliceW, sliceH)): + jTuple = (j[0], j[1], j[2]) + distance = np.linalg.norm(i - j) + if distance > self.radius: + continue + distance = math.exp(-(distance**2) / (self.o_x**2)) + + if jTuple not in distances: + distances[iTuple][jTuple] = distance + + return distances, indexes + + def get_weights(self, inputs): + """Computes the weights matrix for the Soft N-Cuts loss. + + Args: + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + list: List of the weights dict for each image in the batch. + """ + + """ + weights = [] + for n in range(inputs.shape[0]): + weightsChannel = [] + for c in range(inputs.shape[1]): + weightsImage = dict() + for i in self.indexes: + iTuple = (i[0], i[1], i[2]) + weightsImage[iTuple] = dict() + for j in self.indexes: + jTuple = (j[0], j[1], j[2]) + if iTuple in self.distances and jTuple in self.distances[i]: + brightness = ( + inputs[n][c][i[0]][i[1]][i[2]] + - inputs[n][c][j[0]][j[1]][j[2]] + ) ** 2 + brightness = math.exp(-brightness / self.o_i**2) + weightsImage[iTuple][jTuple] = ( + self.distances[iTuple][jTuple] * brightness + ) + + weightsChannel.append(weightsImage) + + weights.append(weightsChannel) + + return weights + + """ + + # Compute the brightness distance of the pixels + flatted_inputs = inputs.view( + inputs.shape[0], inputs.shape[1], -1 + ) # (N, C, H*W*D) + I_diff = torch.subtract( + flatted_inputs.unsqueeze(3), flatted_inputs.unsqueeze(2) + ) # (N, C, H*W*D, H*W*D) + masked_I_diff = torch.mul(I_diff, self.mask) # (N, C, H*W*D, H*W*D) + squared_I_diff = torch.pow(masked_I_diff, 2) # (N, C, H*W*D, H*W*D) + + W_I = torch.exp( + torch.neg(torch.div(squared_I_diff, self.o_i)) + ) # (N, C, H*W*D, H*W*D) + W_I = torch.mul(W_I, self.mask) # (N, C, H*W*D, H*W*D) + + # Get the spatial distance of the pixels + unsqueezed_W_X = self.W_X.view( + 1, 1, self.W_X.shape[0], self.W_X.shape[1] + ) # (1, 1, H*W*D, H*W*D) + + W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) + return W diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index e22f1b62..32683016 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -15,6 +15,7 @@ from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ +from napari_cellseg3d.code_models.models.model_WNet import WNet_ from napari_cellseg3d.utils import LOGGER @@ -29,6 +30,7 @@ # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, + "WNet": WNet_, # "test" : DO NOT USE, reserved for testing } @@ -72,10 +74,12 @@ class ModelInfo: Args: name (str): name of the model model_input_size (Optional[List[int]]): input size of the model + num_classes (int): number of classes for the model """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None + num_classes: int = 2 def get_model(self): try: @@ -243,3 +247,21 @@ class TrainingWorkerConfig: sample_size: List[int] = None do_augmentation: bool = True deterministic_config: DeterministicConfig = DeterministicConfig() + + +################ +# CRF config for WNet +################ + + +@dataclass +class WNetCRFConfig: + "Class to store parameters of WNet CRF post processing" + + # CRF + sa = 10 # 50 + sb = 10 + sg = 1 + w1 = 10 # 50 + w2 = 10 + n_iter = 5 From 90f4fca750c63f3210cb47a2c252d3777be1a2b0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 20 Apr 2023 11:12:59 +0200 Subject: [PATCH 080/577] Specify no grad in inference --- napari_cellseg3d/code_models/model_workers.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7d09e288..d1d73f01 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -488,16 +488,17 @@ def model_output_wrapper(inputs): result = model(inputs) return post_process_transforms(result) - outputs = sliding_window_inference( - inputs, - roi_size=window_size, - sw_batch_size=1, # TODO add param - predictor=model_output_wrapper, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) + with torch.no_grad(): + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) except Exception as e: logger.error(e, exc_info=True) logger.debug("failed to run sliding window inference") @@ -1412,16 +1413,17 @@ def train(self): ) self.log("Performing validation...") try: - val_outputs = sliding_window_inference( - val_inputs, - roi_size=size, - sw_batch_size=self.config.batch_size, - predictor=model, - overlap=0.25, - sw_device=self.config.device, - device=self.config.device, - progress=True, - ) + with torch.no_grad(): + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) except Exception as e: self.raise_error(e, "Error during validation") logger.debug( From 523210493a7c77d5034ca7acbca244eb6c929cfe Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 22 Apr 2023 14:12:32 +0200 Subject: [PATCH 081/577] First functional WNet inference, no CRF --- napari_cellseg3d/code_models/model_workers.py | 46 +++++++++++---- .../code_models/models/model_WNet.py | 3 +- .../code_plugins/plugin_model_inference.py | 57 +++++++++++-------- 3 files changed, 71 insertions(+), 35 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index d1d73f01..87fc5455 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -198,7 +198,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None - stats: ImageStats = None + stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -539,7 +539,10 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - semantic_labels = np.swapaxes(semantic_labels, 0, 2) + total_dims = len(semantic_labels.shape) - 3 + semantic_labels = np.swapaxes( + semantic_labels, 0 + total_dims, 2 + total_dims + ) return InferenceResult( image_id=i + 1, @@ -582,8 +585,10 @@ def save_image( ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" + filetype = self.config.filetype else: original_filename = "_" + filetype = "" time = utils.get_date_time() @@ -594,7 +599,7 @@ def save_image( + original_filename + self.config.model_info.name + f"_{time}_" - + self.config.filetype + + filetype ) try: imwrite(file_path, image) @@ -619,22 +624,35 @@ def aniso_transform(self, image): else: return image - def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + def instance_seg( + self, to_instance, image_id=0, original_filename="layer", channel=None + ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method instance_labels = method.run_method(image=to_instance) + if channel is not None: + channel_id = f"_{channel}" + else: + channel_id = "" + + if self.config.filetype == "": + filetype = "" + else: + filetype = "_" + self.config.filetype + instance_filepath = ( self.config.results_path + "/" + f"Instance_seg_labels_{image_id}_" + original_filename + + channel_id + "_" + self.config.model_info.name - + f"_{utils.get_date_time()}_" - + self.config.filetype + + f"_{utils.get_date_time()}" + + filetype ) imwrite(instance_filepath, instance_labels) @@ -699,13 +717,21 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, stats = self.get_instance_result(out, from_layer=True) + instance_labels_results = [] + stats_results = [] + + for channel in out: + instance_labels, stats = self.get_instance_result( + channel, from_layer=True + ) + instance_labels_results.append(instance_labels) + stats_results.append(stats) return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels, + instance_labels=instance_labels_results, from_layer=True, - stats=stats, + stats=stats_results, ) # @thread_worker(connect={"errored": self.raise_error}) @@ -762,7 +788,7 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=dims, + input_img_size=[dims, dims, dims], device=self.config.device, num_classes=self.config.model_info.num_classes, ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 63a91b10..dffa3b44 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,8 @@ def __init__( ) def forward(self, x): - """Forward pass of the W-Net model.""" + """Forward ENCODER pass of the W-Net model. + Done this way to allow inference on the encoder only when called by sliding_window_inference.""" enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 060964df..55448193 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -734,37 +734,46 @@ def on_yield(self, result: InferenceResult): ) if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + for i, labels in enumerate(result.instance_labels): + # labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + from napari_cellseg3d.utils import LOGGER as log - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + log.debug(f"len stats : {len(result.stats)}") - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + for i, stats in enumerate(result.stats): + # stats = result.stats - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) + + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) + + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) except Exception as e: self.on_error(e) From 7378c0a5e59bc29f43e868a966c528afbd4045cf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:48:12 +0200 Subject: [PATCH 082/577] Create test_models.py --- napari_cellseg3d/_tests/test_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 napari_cellseg3d/_tests/test_models.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py new file mode 100644 index 00000000..e2ba32e0 --- /dev/null +++ b/napari_cellseg3d/_tests/test_models.py @@ -0,0 +1,13 @@ +from napari_cellseg3d.config import MODEL_LIST + + +def test_model_list(): + for model_name in MODEL_LIST.keys(): + dims = 128 + test = MODEL_LIST[model_name]( + input_img_size=[dims, dims, dims], + in_channels=1, + out_channels=1, + dropout_prob=0.3, + ) + assert isinstance(test, MODEL_LIST[model_name]) From c5cebd809cbeb7ad66fe9b3827f02410175743fb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:42:56 +0200 Subject: [PATCH 083/577] Run full suite of pre-commit hooks --- docs/res/guides/custom_model_template.rst | 2 -- napari_cellseg3d/code_models/model_instance_seg.py | 6 ++---- napari_cellseg3d/code_models/model_workers.py | 5 ++--- napari_cellseg3d/code_models/models/model_SwinUNetR.py | 1 + napari_cellseg3d/code_models/models/model_WNet.py | 3 ++- napari_cellseg3d/code_models/models/wnet/crf.py | 8 +++----- napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py | 8 ++++---- napari_cellseg3d/config.py | 1 - napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - 9 files changed, 14 insertions(+), 21 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 9bad49b0..218795b1 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -9,5 +9,3 @@ To add a custom model, you will need a **.py** file with the following structure **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute :: - - diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 4f702a4b..047f23ac 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -9,16 +9,14 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - - # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 87fc5455..f33ec541 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,8 +1,8 @@ import platform +import typing as t from dataclasses import dataclass from math import ceil from pathlib import Path -import typing as t import numpy as np import torch @@ -36,10 +36,9 @@ from monai.transforms import Zoom from monai.utils import set_determinism +# from napari.qt.threading import thread_worker # threads from napari.qt.threading import GeneratorWorker - -# from napari.qt.threading import thread_worker from napari.qt.threading import WorkerBaseSignals # Qt diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index f38409b8..05819e22 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,5 @@ from monai.networks.nets import SwinUNETR + from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index dffa3b44..750b8bdb 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -22,7 +22,8 @@ def __init__( def forward(self, x): """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference.""" + Done this way to allow inference on the encoder only when called by sliding_window_inference. + """ enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index ca11fba2..2ac0875d 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,11 +12,9 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - unary_from_softmax, - create_pairwise_gaussian, - create_pairwise_bilateral, -) +from pydensecrf.utils import create_pairwise_bilateral +from pydensecrf.utils import create_pairwise_gaussian +from pydensecrf.utils import unary_from_softmax __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 6a625355..4e84579f 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -1,15 +1,15 @@ """ Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. -The implementation was adapted and approximated to reduce computational and memory cost. +The implementation was adapted and approximated to reduce computational and memory cost. This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. """ import math + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -import numpy as np from scipy.stats import norm __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" @@ -56,7 +56,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): # self.distances, self.indexes = self.get_distances() """ - + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 32683016..81907f31 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -16,7 +16,6 @@ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ from napari_cellseg3d.code_models.models.model_WNet import WNet_ - from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..48249a94 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage From 686704041f3195e09a3182009dfdbce564056206 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:27:18 +0200 Subject: [PATCH 084/577] Patch for tests action + style --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/code_models/model_instance_seg.py | 6 ++++-- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 + napari_cellseg3d/utils.py | 1 + 5 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ea0a1e46..88a67ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -16,6 +16,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet workflow_dispatch: jobs: diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 047f23ac..c9b09d9c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -9,14 +9,16 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 750b8bdb..4a9ff70d 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -11,7 +11,7 @@ def __init__( out_channels=1, num_classes=2, device="cpu", - **kwargs + **kwargs, ): super().__init__( device=device, diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 48249a94..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index bc6203be..2ae44896 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,6 +2,7 @@ import warnings from datetime import datetime from pathlib import Path + import numpy as np from monai.transforms import Zoom from skimage import io From f3cc365584c00413aa0a342266af339b4b421003 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 16:03:29 +0200 Subject: [PATCH 085/577] Add softNCuts basic test --- napari_cellseg3d/_tests/test_models.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index e2ba32e0..9280b230 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,3 +1,6 @@ +import torch + +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST @@ -11,3 +14,20 @@ def test_model_list(): dropout_prob=0.3, ) assert isinstance(test, MODEL_LIST[model_name]) + + +def test_soft_ncuts_loss(): + dims = 8 + labels = torch.rand([1, 1, dims, dims, dims]) + + loss = SoftNCutsLoss( + data_shape=[dims, dims, dims], + device="cpu", + o_i=4, + o_x=4, + radius=2, + ) + + res = loss.forward(labels, labels) + assert isinstance(res, torch.Tensor) + # assert res > 0 From 92f16c212401798e9ac316363c0012fb61ad4dfe Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 09:41:15 +0200 Subject: [PATCH 086/577] Added crf Co-Authored-By: Nevexios <72894299+nevexios@users.noreply.github.com> --- napari_cellseg3d/code_models/crf.py | 122 ++++++++++++++++++++++++++++ pyproject.toml | 3 + 2 files changed, 125 insertions(+) create mode 100644 napari_cellseg3d/code_models/crf.py diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py new file mode 100644 index 00000000..13f489c7 --- /dev/null +++ b/napari_cellseg3d/code_models/crf.py @@ -0,0 +1,122 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +from warnings import warn + +import numpy as np + +try: + import pydensecrf.densecrf as dcrf + from pydensecrf.utils import create_pairwise_bilateral + from pydensecrf.utils import create_pairwise_gaussian + from pydensecrf.utils import unary_from_softmax + + CRF_INSTALLED = True +except ImportError: + warn( + "pydensecrf not installed, CRF post-processing will not be available. " + "Please install by running pip install cellseg3d[crf]" + ) + CRF_INSTALLED = False + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/pyproject.toml b/pyproject.toml index aabae448..47814de6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ profile = "black" line_length = 79 [project.optional-dependencies] +crf = [ +"git+https://github.com/lucasb-eyer/pydensecrf.git", +] dev = [ "isort", "black", From 1244d40fb4aa93699000f1afacdc5338e197f2de Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 10:08:46 +0200 Subject: [PATCH 087/577] More pre-commit checks --- .pre-commit-config.yaml | 10 +-- napari_cellseg3d/_tests/fixtures.py | 6 +- napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- napari_cellseg3d/_tests/test_utils.py | 25 ++++--- .../_tests/test_weight_download.py | 6 +- napari_cellseg3d/code_models/crf.py | 11 +-- .../code_models/model_framework.py | 16 ++--- .../code_models/model_instance_seg.py | 13 ++-- napari_cellseg3d/code_models/model_workers.py | 64 +++++++++-------- .../code_models/models/unet/model.py | 4 +- .../code_models/models/wnet/crf.py | 8 ++- napari_cellseg3d/code_plugins/plugin_base.py | 6 +- .../code_plugins/plugin_convert.py | 16 ++--- napari_cellseg3d/code_plugins/plugin_crop.py | 9 ++- .../code_plugins/plugin_helper.py | 6 +- .../code_plugins/plugin_metrics.py | 3 +- .../code_plugins/plugin_model_inference.py | 20 +++--- .../code_plugins/plugin_model_training.py | 31 ++++---- .../code_plugins/plugin_review.py | 12 ++-- .../code_plugins/plugin_review_dock.py | 11 ++- .../code_plugins/plugin_utilities.py | 18 ++--- napari_cellseg3d/config.py | 8 +-- .../dev_scripts/artefact_labeling.py | 3 +- napari_cellseg3d/dev_scripts/convert.py | 3 +- .../dev_scripts/correct_labels.py | 3 +- napari_cellseg3d/dev_scripts/drafts.py | 3 +- napari_cellseg3d/dev_scripts/thread_test.py | 16 +++-- napari_cellseg3d/interface.py | 70 +++++++++---------- napari_cellseg3d/utils.py | 61 ++++++++-------- pyproject.toml | 7 +- 30 files changed, 237 insertions(+), 238 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..61ecaae5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index bd6b0ac7..b3044799 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,7 +1,7 @@ -import warnings - from qtpy.QtWidgets import QTextEdit +from napari_cellseg3d.utils import LOGGER as logger + class LogFixture(QTextEdit): """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing""" @@ -13,7 +13,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - warnings.warn(warning) + logger.warning(warning) def error(self, e): raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index cbfd97b2..0abcf387 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index dc57b940..f2a9d32c 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,8 +1,7 @@ import os -import warnings +from functools import partial import numpy as np -import pytest import torch from napari_cellseg3d import utils @@ -33,6 +32,10 @@ def test_fill_list_in_between(): assert utils.fill_list_in_between(list, 2, "") == res + fill = partial(utils.fill_list_in_between, n=2, fill_value="") + + assert fill(list) == res + def test_align_array_sizes(): im = np.zeros((128, 512, 256)) @@ -79,15 +82,15 @@ def test_get_padding_dim(): tensor = torch.randn(2000, 30, 40) size = tensor.size() - warn = warnings.warn( - "Warning : a very large dimension for automatic padding has been computed.\n" - "Ensure your images are of an appropriate size and/or that you have enough memory." - "The padding value is currently 2048." - ) - - pad = utils.get_padding_dim(size) - - pytest.warns(warn, (lambda: utils.get_padding_dim(size))) + # warn = logger.warning( + # "Warning : a very large dimension for automatic padding has been computed.\n" + # "Ensure your images are of an appropriate size and/or that you have enough memory." + # "The padding value is currently 2048." + # ) + # + # pad = utils.get_padding_dim(size) + # + # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) assert pad == [2048, 32, 64] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 51189e4b..972550e9 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,5 +1,7 @@ -from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR -from napari_cellseg3d.code_models.model_workers import WeightsDownloader +from napari_cellseg3d.code_models.model_workers import ( + PRETRAINED_WEIGHTS_DIR, + WeightsDownloader, +) # DISABLED, causes GitHub actions to freeze diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 13f489c7..fc1e0b90 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -16,15 +16,18 @@ try: import pydensecrf.densecrf as dcrf - from pydensecrf.utils import create_pairwise_bilateral - from pydensecrf.utils import create_pairwise_gaussian - from pydensecrf.utils import unary_from_softmax + from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, + ) CRF_INSTALLED = True except ImportError: warn( "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]" + "Please install by running pip install cellseg3d[crf]", + stacklevel=1, ) CRF_INSTALLED = False diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 1e6b934a..1c3abe3f 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,20 +1,16 @@ -import warnings from pathlib import Path import napari import torch # Qt -from qtpy.QtWidgets import QProgressBar -from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QProgressBar, QSizePolicy # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -137,11 +133,11 @@ def save_log(self): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) else: - warnings.warn(f"No logger defined : Log is {self.log}") + logger.warning(f"No logger defined : Log is {self.log}") def save_log_to_path(self, path): """Saves the worker log to a specific path. Cannot be used with connect. @@ -163,7 +159,7 @@ def save_log_to_path(self, path): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) @@ -172,7 +168,7 @@ def display_status_report(self): (usually when starting a worker)""" # if self.container_report is None or self.log is None: - # warnings.warn( + # logger.warning( # "Status report widget has been closed. Trying to re-instantiate..." # ) # self.container_report = QWidget() diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c9b09d9c..0ad3c595 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,19 +1,18 @@ from dataclasses import dataclass +from functools import partial from typing import List import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area @@ -359,8 +358,10 @@ def sphericity(region): volume = [region.area for region in properties] - def fill(lst, n=len(properties) - 1): - return fill_list_in_between(lst, n, "") + # def fill(lst, n=len(properties) - 1): + # return fill_list_in_between(lst, n, "") + + fill = partial(fill_list_in_between, n=len(properties) - 1, fill_value="") if len(volume_image.flatten()) != 0: ratio = fill([np.sum(volume) / len(volume_image.flatten())]) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index f33ec541..842a86f8 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -8,38 +8,41 @@ import torch # MONAI -from monai.data import CacheDataset -from monai.data import DataLoader -from monai.data import Dataset -from monai.data import decollate_batch -from monai.data import pad_list_data_collate -from monai.data import PatchDataset +from monai.data import ( + CacheDataset, + DataLoader, + Dataset, + PatchDataset, + decollate_batch, + pad_list_data_collate, +) from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric -from monai.transforms import AddChannel -from monai.transforms import AsDiscrete -from monai.transforms import Compose -from monai.transforms import EnsureChannelFirstd -from monai.transforms import EnsureType -from monai.transforms import EnsureTyped -from monai.transforms import LoadImaged -from monai.transforms import Orientationd -from monai.transforms import Rand3DElasticd -from monai.transforms import RandAffined -from monai.transforms import RandFlipd -from monai.transforms import RandRotate90d -from monai.transforms import RandShiftIntensityd -from monai.transforms import RandSpatialCropSamplesd -from monai.transforms import SpatialPad -from monai.transforms import SpatialPadd -from monai.transforms import ToTensor -from monai.transforms import Zoom +from monai.transforms import ( + AddChannel, + AsDiscrete, + Compose, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + LoadImaged, + Orientationd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandShiftIntensityd, + RandSpatialCropSamplesd, + SpatialPad, + SpatialPadd, + ToTensor, + Zoom, +) from monai.utils import set_determinism # from napari.qt.threading import thread_worker # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -47,11 +50,12 @@ from tqdm import tqdm # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index 6cc76be6..9614a555 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -1,12 +1,10 @@ import torch.nn as nn from napari_cellseg3d.code_models.models.unet.buildingblocks import ( + DoubleConv, create_decoders, -) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index 2ac0875d..004db3a1 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,9 +12,11 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import create_pairwise_bilateral -from pydensecrf.utils import create_pairwise_gaussian -from pydensecrf.utils import unary_from_softmax +from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, +) __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 02c9fbff..7f6317b7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from pathlib import Path @@ -6,8 +5,7 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QTabWidget -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QTabWidget, QWidget # local from napari_cellseg3d import interface as ui @@ -403,7 +401,7 @@ def load_dataset_paths(self): file_paths = sorted(Path(directory).glob("*" + filetype)) if len(file_paths) == 0: - warnings.warn( + logger.warning( f"The folder does not contain any compatible {filetype} files.\n" f"Please check the validity of the folder and images." ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 547e0233..edbeaa6e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,18 +1,18 @@ -import warnings from pathlib import Path import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -84,7 +84,7 @@ def show_result(viewer, layer, image, name): logger.debug("Added resulting label layer") viewer.add_labels(image, name=name) else: - warnings.warn( + logger.warning( f"Results not shown, unsupported layer type {type(layer)}" ) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 789be9e5..4b4e7d82 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import napari @@ -248,7 +247,7 @@ def _start(self): # maybe use singletons or make docked widgets attributes that are hidden upon opening if not self._check_ready(): - warnings.warn("Please select at least one valid layer !") + logger.warning("Please select at least one valid layer !") return # self._viewer.window.remove_dock_widget(self.parent()) # no need to close utils ? @@ -332,7 +331,7 @@ def add_isotropic_layer( self, layer, colormap="inferno", - contrast_lim=[200, 1000], # TODO generalize ? + contrast_lim=(200, 1000), # TODO generalize ? opacity=0.7, visible=True, ): @@ -438,8 +437,8 @@ def _add_crop_sliders( for i in range(len(crop_sizes)): if crop_sizes[i] > im1_stack.shape[i]: crop_sizes[i] = im1_stack.shape[i] - warnings.warn( - f"WARNING : Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {im1_stack.shape[i]}" + logger.warning( + f"Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {im1_stack.shape[i]}" ) cropx, cropy, cropz = crop_sizes diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 083b269b..f8ac18ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -4,10 +4,8 @@ # Qt from qtpy.QtCore import QSize -from qtpy.QtGui import QIcon -from qtpy.QtGui import QPixmap -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtGui import QIcon, QPixmap +from qtpy.QtWidgets import QVBoxLayout, QWidget # local from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index b2356526..114025f6 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -5,8 +5,7 @@ FigureCanvasQTAgg as FigureCanvas, ) from matplotlib.figure import Figure -from monai.transforms import SpatialPad -from monai.transforms import ToTensor +from monai.transforms import SpatialPad, ToTensor from tifffile import imread from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 55448193..c509fe61 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,4 +1,3 @@ -import warnings from functools import partial import napari @@ -6,14 +5,19 @@ import pandas as pd # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) + +logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -538,7 +542,7 @@ def start(self): if not self._check_results_path(save_path): msg = f"ERROR: please set valid results path. Current path is {save_path}" self.log.print_and_log(msg) - warnings.warn(msg) + logger.warning(msg) else: if self.results_path is None: self.results_path = save_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 7bd1b0bf..88991f43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,5 +1,4 @@ import shutil -import warnings from functools import partial from pathlib import Path @@ -14,23 +13,26 @@ from matplotlib.figure import Figure # MONAI -from monai.losses import DiceCELoss -from monai.losses import DiceFocalLoss -from monai.losses import DiceLoss -from monai.losses import FocalLoss -from monai.losses import GeneralizedDiceLoss -from monai.losses import TverskyLoss +from monai.losses import ( + DiceCELoss, + DiceFocalLoss, + DiceLoss, + FocalLoss, + GeneralizedDiceLoss, + TverskyLoss, +) # Qt from qtpy.QtWidgets import QSizePolicy # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import TrainingReport -from napari_cellseg3d.code_models.model_workers import TrainingWorker +from napari_cellseg3d.code_models.model_workers import ( + TrainingReport, + TrainingWorker, +) NUMBER_TABS = 3 DEFAULT_PATCH_SIZE = 64 @@ -415,8 +417,7 @@ def check_ready(self): if self.images_filepaths != [] and self.labels_filepaths != []: return True else: - warnings.formatwarning = utils.format_Warning - warnings.warn("Image and label paths are not correctly set") + logger.warning("Image and label paths are not correctly set") return False def _build(self): @@ -784,7 +785,7 @@ def start(self): if not self.check_ready(): # issues a warning if not ready err = "Aborting, please set all required paths" self.log.print_and_log(err) - warnings.warn(err) + logger.warning(err) return if self.worker is not None: @@ -1040,7 +1041,7 @@ def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) if len(self.loss_values) == 0 or self.loss_values is None: - warnings.warn("No loss values to add to csv !") + logger.warning("No loss values to add to csv !") return self.df = pd.DataFrame( diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index a1a167a4..235595e4 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -11,18 +10,15 @@ from matplotlib.figure import Figure # Qt -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QLineEdit, QSizePolicy from tifffile import imwrite # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -184,7 +180,7 @@ def check_image_data(self): if cfg.labels is not None: if cfg.image.shape != cfg.labels.shape: - warnings.warn( + logger.warning( "Image and label dimensions do not match ! Please load matching images" ) @@ -240,7 +236,7 @@ def run_review(self): self._reset() previous_viewer.close() except ValueError as e: - warnings.warn( + logger.warning( f"An exception occurred : {e}. Please ensure you have entered all required parameters." ) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 8a25d6a6..6cee7c94 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,14 +1,11 @@ -import warnings -from datetime import datetime -from datetime import timedelta +from datetime import datetime, timedelta from pathlib import Path import napari import pandas as pd # Qt -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QVBoxLayout, QWidget from napari_cellseg3d import interface as ui from napari_cellseg3d import utils @@ -18,7 +15,7 @@ GUI_MINIMUM_HEIGHT = 300 TIMER_FORMAT = "%H:%M:%S" - +logger = utils.LOGGER """ plugin_dock.py ==================================== @@ -266,7 +263,7 @@ def update_dm(self, slice_num): def button_func(self): # updates csv every time you press button... if self.viewer.dims.ndisplay != 2: # TODO test if undefined behaviour or if okay - warnings.warn("Please switch back to 2D mode !") + logger.warning("Please switch back to 2D mode !") return self.update_time_csv() diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index e141bfe5..462ee450 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,17 +2,17 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils -from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils +from napari_cellseg3d.code_plugins.plugin_convert import ( + AnisoUtils, + RemoveSmallUtils, + ThresholdUtils, + ToInstanceUtils, + ToSemanticUtils, +) from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -82,7 +82,7 @@ def _update_visibility(self): # print("vis. updated") # print(self.utils_widgets) self._hide_all() - for i, w in enumerate(self.utils_widgets): + for _i, w in enumerate(self.utils_widgets): if isinstance(w, widget_class): w.setVisible(True) w.adjustSize() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 81907f31..1f3fd4c1 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -1,9 +1,7 @@ import datetime -import warnings from dataclasses import dataclass from pathlib import Path -from typing import List -from typing import Optional +from typing import List, Optional import napari import numpy as np @@ -85,9 +83,9 @@ def get_model(self): return MODEL_LIST[self.name] except KeyError as e: msg = f"Model {self.name} is not defined" - warnings.warn(msg) logger.warning(msg) - raise KeyError(e) + logger.warning(msg) + raise KeyError from e @staticmethod def get_model_name_list(): diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..b4712aec 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 479a07dd..641de627 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,8 +2,7 @@ import os import numpy as np -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..2ab60332 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py index adfb7914..cdd02256 100644 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ b/napari_cellseg3d/dev_scripts/drafts.py @@ -1,8 +1,7 @@ import napari import numpy as np from magicgui import magicgui -from napari.types import ImageData -from napari.types import LabelsData +from napari.types import ImageData, LabelsData @magicgui(call_button="Run Threshold") diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 998645cb..20668125 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -3,13 +3,15 @@ import napari import numpy as np from napari.qt.threading import thread_worker -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QProgressBar -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import ( + QGridLayout, + QLabel, + QProgressBar, + QPushButton, + QTextEdit, + QVBoxLayout, + QWidget, +) @thread_worker diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 484d137d..4b560a6d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,8 +1,6 @@ import threading -import warnings from functools import partial -from typing import List -from typing import Optional +from typing import List, Optional import napari @@ -11,32 +9,30 @@ from qtpy import QtCore # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -290,10 +286,10 @@ def print_and_log(self, text, printing=True): self.lock.release() def warn(self, warning): - """Show warnings.warn from another thread""" + """Show logger.warning from another thread""" self.lock.acquire() try: - warnings.warn(warning) + logger.warning(warning) finally: self.lock.release() @@ -538,7 +534,7 @@ def _build_container(self): ) def _warn_outside_bounds(self, default): - warnings.warn( + logger.warning( f"Default value {default} was outside of the ({self.minimum()}:{self.maximum()}) range" ) @@ -583,7 +579,7 @@ def slider_value(self): try: return self.value() / self._divide_factor except ZeroDivisionError as e: - raise ZeroDivisionError( + raise ZeroDivisionError from ( f"Divide factor cannot be 0 for Slider : {e}" ) @@ -793,7 +789,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: - warnings.warn("Please select a valid layer !") + logger.warning("Please select a valid layer !") return return self._viewer.layers[self.layer_name()].data @@ -1032,7 +1028,7 @@ def make_n_spinboxes( raise ValueError("Cannot make less than 2 spin boxes") boxes = [] - for i in range(n): + for _i in range(n): box = class_(min, max, default, step, parent, fixed) boxes.append(box) return boxes @@ -1190,7 +1186,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), filetype: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1214,7 +1210,7 @@ def open_file_dialog( def open_folder_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), ): default_path = utils.parse_default_path(possible_paths) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 2ae44896..c6c05ad1 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,5 +1,4 @@ import logging -import warnings from datetime import datetime from pathlib import Path @@ -235,7 +234,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # warnings.warn( + # logger.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -243,7 +242,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 if pad >= 256: - warnings.warn( + LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." f"The padding value is currently {pad}." @@ -343,14 +342,14 @@ def annotation_to_input(label_ermito): # pass -def fill_list_in_between(lst, n, elem): +def fill_list_in_between(lst, n, fill_value): """Fills a list with n * elem between each member of list. Example with list = [1,2,3], n=2, elem='&' : returns [1, &, &,2,&,&,3,&,&] Args: lst: list to fill n: number of elements to add - elem: added n times after each element of list + fill_value: added n times after each element of list Returns : Filled list @@ -359,13 +358,13 @@ def fill_list_in_between(lst, n, elem): for i in range(len(lst)): temp_list = [lst[i]] while len(temp_list) < n + 1: - temp_list.append(elem) + temp_list.append(fill_value) if i < len(lst) - 1: new_list += temp_list else: new_list.append(lst[i]) - for j in range(n): - new_list.append(elem) + for _j in range(n): + new_list.append(fill_value) return new_list @@ -533,26 +532,26 @@ def select_train_data(dataframe, ori_imgs, label_imgs, ori_filenames): return np.array(train_ori_imgs), np.array(train_label_imgs) -def format_Warning(message, category, filename, lineno, line=""): - """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` - - Args: - message: warning message - category: which type of warning has been raised - filename: file - lineno: line number - line: unused - - Returns: format - - """ - return ( - str(filename) - + ":" - + str(lineno) - + ": " - + category.__name__ - + ": " - + str(message) - + "\n" - ) +# def format_Warning(message, category, filename, lineno, line=""): +# """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` +# +# Args: +# message: warning message +# category: which type of warning has been raised +# filename: file +# lineno: line number +# line: unused +# +# Returns: format +# +# """ +# return ( +# str(filename) +# + ":" +# + str(lineno) +# + ": " +# + category.__name__ +# + ": " +# + str(message) +# + "\n" +# ) diff --git a/pyproject.toml b/pyproject.toml index 47814de6..21c40a6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,12 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). +select = [ + "E", "F", "W", + "I", + "B", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) ignore = ["E501", "E741"] [tool.black] From 43eaaa955374e0d9641e21b13e8921e745eb5fce Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:29:42 +0200 Subject: [PATCH 088/577] Functional CRF --- napari_cellseg3d/_tests/test_models.py | 39 +++ napari_cellseg3d/code_models/crf.py | 98 ++++++- .../code_models/model_instance_seg.py | 6 +- napari_cellseg3d/code_models/model_workers.py | 48 +++- napari_cellseg3d/code_plugins/plugin_base.py | 25 +- .../code_plugins/plugin_convert.py | 102 ++----- napari_cellseg3d/code_plugins/plugin_crf.py | 262 ++++++++++++++++++ napari_cellseg3d/code_plugins/plugin_crop.py | 7 +- .../code_plugins/plugin_model_inference.py | 32 ++- .../code_plugins/plugin_utilities.py | 15 +- napari_cellseg3d/config.py | 16 ++ napari_cellseg3d/interface.py | 19 +- napari_cellseg3d/utils.py | 81 +++++- 13 files changed, 629 insertions(+), 121 deletions(-) create mode 100644 napari_cellseg3d/code_plugins/plugin_crf.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 9280b230..1fc15872 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,9 +1,18 @@ +import numpy as np import torch +from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +def test_correct_shape_for_crf(): + test = np.random.rand(1, 1, 8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + test = np.random.rand(8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + + def test_model_list(): for model_name in MODEL_LIST.keys(): dims = 128 @@ -31,3 +40,33 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) # assert res > 0 + + +def test_crf(qtbot): + dims = 8 + mock_image = np.random.rand(1, dims, dims, dims) + mock_label = np.random.rand(2, dims, dims, dims) + + crf = CRFWorker(mock_image, mock_label) + + def on_yield(result): + assert isinstance(result, np.ndarray) + assert result.shape[-3:] == mock_label.shape[-3:] + + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) + + mock_image = mock_image[0] + mock_label = mock_label[0] + + crf = CRFWorker(mock_image, mock_label) + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index fc1e0b90..a0146a5e 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -9,11 +9,8 @@ Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. """ - from warnings import warn -import numpy as np - try: import pydensecrf.densecrf as dcrf from pydensecrf.utils import ( @@ -31,6 +28,12 @@ ) CRF_INSTALLED = False + +import numpy as np +from napari.qt.threading import GeneratorWorker + +from napari_cellseg3d.config import CRFConfig + __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ "Yves Paychère", @@ -49,6 +52,16 @@ ] +def correct_shape_for_crf(image): + if len(image.shape) == 4: + return image + if len(image.shape) > 4: + image = np.squeeze(image, axis=0) + if len(image.shape) < 4: + image = np.expand_dims(image, axis=0) + return correct_shape_for_crf(image) + + def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): """CRF post-processing step for the W-Net, applied to a batch of images. @@ -62,6 +75,8 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): Returns: np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. """ + if not CRF_INSTALLED: + return None return np.stack( [ @@ -83,10 +98,16 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + w1 (float): weight of the appearance/bilateral kernel. + w2 (float): weight of the smoothness/gaussian kernel. Returns: np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. """ + + if not CRF_INSTALLED: + return None + d = dcrf.DenseCRF( image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] ) @@ -123,3 +144,74 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): return np.array(Q).reshape( (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) ) + + +def crf_with_config(image, prob, config: CRFConfig = None): + if config is None: + config = CRFConfig() + if image.shape[-3:] != prob.shape[-3:]: + raise ValueError( + f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" + f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" + ) + + image = correct_shape_for_crf(image) + + return crf( + image, + prob, + config.sa, + config.sb, + config.sg, + config.w1, + config.w2, + config.n_iters, + ) + + +class CRFWorker(GeneratorWorker): + """Worker for the CRF post-processing step for the W-Net.""" + + def __init__( + self, + images_list, + labels_list, + config: CRFConfig = None, + log=None, + ): + super().__init__(self._run_crf_job) + + self.images = images_list + self.labels = labels_list + if config is None: + self.config = CRFConfig() + else: + self.config = config + self.log = log + + # TODO(cyril) : add progress bar into log ? or do it in inference + def _run_crf_job(self): + """Runs the CRF post-processing step for the W-Net.""" + if not CRF_INSTALLED: + raise ImportError("pydensecrf is not installed.") + + for image, labels in zip(self.images, self.labels): + if len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + + if len(labels.shape) == 3: + labels = np.expand_dims(labels, axis=0) + + if image.shape[-3:] != labels.shape[-3:]: + raise ValueError("Image and labels must have the same shape.") + + yield crf( + image, + labels, + self.config.sa, + self.config.sb, + self.config.sg, + self.config.w1, + self.config.w2, + n_iter=self.config.n_iters, + ) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 0ad3c595..eb660820 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -62,7 +62,7 @@ def __init__( 1, divide_factor=100, text_label="", - parent=None, + parent=widget_parent, ), ) self.sliders.append(getattr(self, widget)) @@ -73,7 +73,9 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(text_label="", parent=None), + ui.DoubleIncrementCounter( + text_label="", parent=widget_parent + ), ) self.counters.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 842a86f8..06285aea 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -52,6 +52,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import crf_with_config from napari_cellseg3d.code_models.model_instance_seg import ( ImageStats, volume_stats, @@ -201,6 +202,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None + crf_results: np.array = None stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -527,7 +529,8 @@ def create_inference_result( self, semantic_labels, instance_labels, - from_layer: bool, + crf_results=None, + from_layer: bool = False, original=None, stats=None, i=0, @@ -542,15 +545,19 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - total_dims = len(semantic_labels.shape) - 3 + extra_dims = len(semantic_labels.shape) - 3 semantic_labels = np.swapaxes( - semantic_labels, 0 + total_dims, 2 + total_dims + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims ) return InferenceResult( image_id=i + 1, original=original, instance_labels=instance_labels, + crf_results=crf_results, stats=stats, result=semantic_labels, model_name=self.config.model_info.name, @@ -585,6 +592,7 @@ def save_image( image, from_layer=False, i=0, + additional_info="", ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" @@ -598,7 +606,7 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"Prediction_{i+1}" + + f"{additional_info}_Prediction_{i+1}" + original_filename + self.config.model_info.name + f"_{time}_" @@ -680,6 +688,15 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.save_image(out, i=i) instance_labels, stats = self.get_instance_result(out, i=i) + if self.config.use_crf: + try: + crf_results = self.run_crf(inputs, out, image_id=i) + + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + crf_results = None + else: + crf_results = None original = np.array(inf_data["image"]).astype(np.float32) @@ -688,12 +705,29 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): return self.create_inference_result( out, instance_labels, + crf_results, from_layer=False, original=original, stats=stats, i=i, ) + def run_crf(self, image, labels, image_id=0): + self.log(f"IMAGE SHAPE : {image.shape}") + self.log(f"LABEL SHAPE : {labels.shape}") + + try: + crf_results = crf_with_config( + image, labels, config=self.config.crf_config + ) + self.save_image( + crf_results, i=image_id, additional_info="CRF", from_layer=True + ) + return crf_results + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + return None + def stats_csv(self, instance_labels): if self.config.compute_stats: stats = volume_stats(instance_labels) @@ -730,9 +764,15 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels_results.append(instance_labels) stats_results.append(stats) + if self.config.use_crf: + crf_results = self.run_crf(image, out) + else: + crf_results = None + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels_results, + crf_results=crf_results, from_layer=True, stats=stats_results, ) diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 7f6317b7..26da7a42 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -46,15 +46,15 @@ def __init__( self.image_path = None """str: path to image folder""" - self.show_image_io = loads_images + self._show_image_io = loads_images self.label_path = None """str: path to label folder""" - self.show_label_io = loads_labels + self._show_label_io = loads_labels self.results_path = None """str: path to results folder""" - self.show_results_io = has_results + self._show_results_io = has_results self._default_path = [self.image_path, self.label_path] @@ -116,7 +116,7 @@ def show_menu(_, event): def _build_io_panel(self): self.io_panel = ui.GroupedWidget("Data") - + self.save_label = ui.make_label("Save location :", parent=self) # self.io_panel.setToolTip("IO Panel") ui.add_widgets( @@ -128,6 +128,7 @@ def _build_io_panel(self): self.filetype_choice, self.image_filewidget, self.labels_filewidget, + self.save_label, self.results_filewidget, ], ) @@ -137,25 +138,25 @@ def _build_io_panel(self): return self.io_panel def _remove_unused(self): - if not self.show_label_io: + if not self._show_label_io: self.labels_filewidget = None self.label_layer_loader = None - if not self.show_image_io: + if not self._show_image_io: self.image_layer_loader = None self.image_filewidget = None - if not self.show_results_io: + if not self._show_results_io: self.results_filewidget = None def _set_io_visibility(self): ################## # Show when layer is selected - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_layer_loader, self.layer_choice) else: self._hide_io_element(self.image_layer_loader) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.label_layer_loader, self.layer_choice) else: self._hide_io_element(self.label_layer_loader) @@ -165,15 +166,15 @@ def _set_io_visibility(self): f = self.folder_choice self._show_io_element(self.filetype_choice, f) - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_filewidget, f) else: self._hide_io_element(self.image_filewidget) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.labels_filewidget, f) else: self._hide_io_element(self.labels_filewidget) - if not self.show_results_io: + if not self._show_results_io: self._hide_io_element(self.results_filewidget) self.folder_choice.toggle() diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index edbeaa6e..44123e34 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -3,7 +3,7 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -15,80 +15,12 @@ ) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -# TODO break down into multiple mini-widgets -# TODO create parent class for utils modules to avoid duplicates - -MAX_W = 200 -MAX_H = 1000 +MAX_W = ui.UTILS_MAX_WIDTH +MAX_H = ui.UTILS_MAX_HEIGHT logger = utils.LOGGER -def save_folder(results_path, folder_name, images, image_paths): - """ - Saves a list of images in a folder - - Args: - results_path: Path to the folder containing results - folder_name: Name of the folder containing results - images: List of images to save - image_paths: list of filenames of images - """ - results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False, parents=True) - - for file, image in zip(image_paths, images): - path = results_folder / Path(file).name - - imwrite( - path, - image, - ) - logger.info(f"Saved processed folder as : {results_folder}") - - -def save_layer(results_path, image_name, image): - """ - Saves an image layer at the specified path - - Args: - results_path: path to folder containing result - image_name: image name for saving - image: data array containing image - - Returns: - - """ - path = str(results_path / Path(image_name)) # TODO flexible filetype - logger.info(f"Saved as : {path}") - imwrite(path, image) - - -def show_result(viewer, layer, image, name): - """ - Adds layers to a viewer to show result to user - - Args: - viewer: viewer to add layer in - layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer - image: the data array containing the image - name: name of the added layer - - Returns: - - """ - if isinstance(layer, napari.layers.Image): - logger.debug("Added resulting image layer") - viewer.add_image(image, name=name) - elif isinstance(layer, napari.layers.Labels): - logger.debug("Added resulting label layer") - viewer.add_labels(image, name=name) - else: - logger.warning( - f"Results not shown, unsupported layer type {type(layer)}" - ) - - class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -154,12 +86,12 @@ def _start(self): data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) - save_layer( + utils.save_layer( self.results_path, f"isotropic_{layer.name}_{utils.get_date_time()}.tif", isotropic_image, ) - show_result( + utils.show_result( self._viewer, layer, isotropic_image, @@ -172,7 +104,7 @@ def _start(self): utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"isotropic_results_{utils.get_date_time()}", images, @@ -253,12 +185,12 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"cleared_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) elif self.folder_choice.isChecked(): @@ -267,7 +199,7 @@ def _start(self): clear_small_objects(file, remove_size, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"small_removed_results_{utils.get_date_time()}", images, @@ -334,12 +266,12 @@ def _start(self): data = np.array(layer.data) semantic = to_semantic(data) - save_layer( + utils.save_layer( self.results_path, f"semantic_{layer.name}_{utils.get_date_time()}.tif", semantic, ) - show_result( + utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) elif self.folder_choice.isChecked(): @@ -348,7 +280,7 @@ def _start(self): to_semantic(file, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"semantic_results_{utils.get_date_time()}", images, @@ -419,7 +351,7 @@ def _start(self): data = np.array(layer.data) instance = self.instance_widgets.run_method(data) - save_layer( + utils.save_layer( self.results_path, f"instance_{layer.name}_{utils.get_date_time()}.tif", instance, @@ -434,7 +366,7 @@ def _start(self): self.instance_widgets.run_method(imread(file)) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"instance_results_{utils.get_date_time()}", images, @@ -514,12 +446,12 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"threshold_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) elif self.folder_choice.isChecked(): @@ -528,7 +460,7 @@ def _start(self): self.function(imread(file), remove_size) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"threshold_results_{utils.get_date_time()}", images, diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py new file mode 100644 index 00000000..3dbd47bb --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -0,0 +1,262 @@ +from functools import partial +from pathlib import Path + +import napari.layers +from qtpy.QtWidgets import QSizePolicy +from tqdm import tqdm + +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage +from napari_cellseg3d.utils import LOGGER as logger + + +# TODO add CRF on folder +class CRFParamsWidget(ui.GroupedWidget): + """Use this widget when adding the crf as part of another widget (rather than a standalone widget)""" + + def __init__(self, parent=None): + super().__init__(title="CRF parameters", parent=parent) + ####### + # CRF params # + self.sa_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Alpha std" + ) + self.sb_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Beta std" + ) + self.sg_choice = ui.DoubleIncrementCounter( + default=1, parent=self, text_label="Gamma std" + ) + self.w1_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Weight appearance" + ) + self.w2_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Weight smoothness" + ) + self.n_iter_choice = ui.IntIncrementCounter( + default=5, parent=self, text_label="Number of iterations" + ) + ####### + self._build() + self._set_tooltips() + + def _build(self): + ui.add_widgets( + self.layout, + [ + # self.sa_choice.label, + self.sa_choice, + # self.sb_choice.label, + self.sb_choice, + # self.sg_choice.label, + self.sg_choice, + # self.w1_choice.label, + self.w1_choice, + # self.w2_choice.label, + self.w2_choice, + # self.n_iter_choice.label, + self.n_iter_choice, + ], + ) + self.set_layout() + + def _set_tooltips(self): + self.sa_choice.setToolTip( + "SA : Standard deviation of the Gaussian kernel in the appearance term." + ) + self.sb_choice.setToolTip( + "SB : Standard deviation of the Gaussian kernel in the smoothness term." + ) + self.sg_choice.setToolTip( + "SG : Standard deviation of the Gaussian kernel in the gradient term." + ) + self.w1_choice.setToolTip( + "W1 : Weight of the appearance term in the CRF." + ) + self.w2_choice.setToolTip( + "W2 : Weight of the smoothness term in the CRF." + ) + self.n_iter_choice.setToolTip("Number of iterations of the CRF.") + + def make_config(self): + return config.CRFConfig( + sa=self.sa_choice.value(), + sb=self.sb_choice.value(), + sg=self.sg_choice.value(), + w1=self.w1_choice.value(), + w2=self.w2_choice.value(), + n_iters=self.n_iter_choice.value(), + ) + + +class CRFWidget(BasePluginSingleImage): + def __init__(self, viewer, parent=None): + """ + Create a widget for CRF post-processing. + Args: + viewer: napari viewer to display the widget + parent: parent widget. Defaults to None. + """ + super().__init__(viewer, parent) + self._viewer = viewer + + self.start_button = ui.Button("Start", self._start, parent=self) + self.crf_params_widget = CRFParamsWidget(parent=self) + self.io_panel = self._build_io_panel() + self.io_panel.setVisible(False) + + self.results_filewidget.setVisible(True) + self.label_layer_loader.setVisible(True) + self.label_layer_loader.set_layer_type( + napari.layers.Image + ) # to load all crf-compatible inputs, not int only + self.image_layer_loader.setVisible(True) + self.start_button.setVisible(True) + + self.result_layer = None + self.result_name = None + self.crf_results = [] + + self.results_path = Path.home() / Path("cellseg3d/crf") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._container = ui.ContainerWidget(parent=self, l=11, t=11, r=11) + self.layout = self._container.layout + + self._build() + + self.worker = None + self.log = None + + def _build(self): + self.setMinimumWidth(100) + ui.add_widgets( + self.layout, + [ + self.image_layer_loader, + self.label_layer_loader, + self.save_label, + self.results_filewidget, + ui.make_label(""), + self.crf_params_widget, + ui.make_label(""), + self.start_button, + ], + ) + # self.io_panel.setLayout(self.io_panel.layout) + self.setLayout(self.layout) + + ui.ScrollArea.make_scrollable( + self.layout, self, max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT] + ) + self._container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return self._container + + def make_config(self): + return self.crf_params_widget.make_config() + + def _check_ready(self): + if len(self.label_layer_loader.layer_list) < 1: + logger.warning("No label layer loaded") + return False + if len(self.image_layer_loader.layer_list) < 1: + logger.warning("No image layer loaded") + return False + + if len(self.label_layer_loader.layer_data().shape) < 3: + logger.warning("Label layer must be 3D") + return False + if len(self.image_layer_loader.layer_data().shape) < 3: + logger.warning("Image layer must be 3D") + return False + if ( + self.label_layer_loader.layer_data().shape[-3:] + != self.image_layer_loader.layer_data().shape[-3:] + ): + logger.warning("Image and label layers must have the same shape!") + return False + + return True + + def run_crf_on_batch(self, images_list: list, labels_list: list, log=None): + self.crf_results = [] + for image, label in zip(images_list, labels_list): + tqdm( + unit="B", + total=len(images_list), + position=0, + file=log, + ) + result = crf_with_config(image, label, self.make_config()) + self.crf_results.append(result) + return self.crf_results + + def _prepare_worker(self, images_list: list, labels_list: list): + self.worker = CRFWorker( + images_list=images_list, + labels_list=labels_list, + config=self.make_config(), + ) + + self.worker.started.connect(self._on_start) + self.worker.yielded.connect(partial(self._on_yield)) + self.worker.errored.connect(partial(self._on_error)) + self.worker.finished.connect(self._on_finish) + + def _start(self): + if not self._check_ready(): + return + + self.result_layer = self.label_layer_loader.layer() + self.result_name = self.label_layer_loader.layer_name() + + self.results_path.mkdir(exist_ok=True, parents=True) + + image_list = [self.image_layer_loader.layer_data()] + labels_list = [self.label_layer_loader.layer_data()] + [logger.debug(f"Image shape: {image.shape}") for image in image_list] + [ + logger.debug(f"Label shape: {labels.shape}") + for labels in labels_list + ] + + self._prepare_worker(image_list, labels_list) + + if self.worker.is_running: # if worker is running, tries to stop + logger.info("Stop request, waiting for previous job to finish") + self.start_button.setText("Stopping...") + self.worker.quit() + else: # once worker is started, update buttons + self.start_button.setText("Running...") + logger.info("Starting CRF...") + self.worker.start() + + def _on_yield(self, result): + self.crf_results.append(result) + + utils.save_layer( + self.results_filewidget.text_field.text(), + str(self.result_name + "_crf.tif"), + result, + ) + self._viewer.add_image( + result, + name="crf_" + self.result_name, + ) + + def _on_start(self): + self.crf_results = [] + + def _on_finish(self): + self.worker = None + + def _on_error(self, error): + logger.error(error) + self.start_button.setText("Start") + self.worker.quit() + self.worker = None diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 4b4e7d82..7363d91c 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -174,7 +174,12 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + ui.ScrollArea.make_scrollable( + layout, + self, + max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], + min_wh=[200, 200], + ) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index c509fe61..03381779 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -16,6 +16,7 @@ InferenceResult, InferenceWorker, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget logger = utils.LOGGER @@ -195,9 +196,17 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) + self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( - "Run instance segmentation", func=self._toggle_display_instance + "Run instance segmentation", + func=self._toggle_display_instance, + parent=self, + ) + self.use_crf = ui.CheckBox( + "Use CRF post-processing", + func=self._toggle_display_crf, + parent=self, ) self.save_stats_to_csv_box = ui.CheckBox( @@ -307,6 +316,10 @@ def _toggle_display_thresh(self): self.thresholding_checkbox, self.thresholding_slider.container ) + def _toggle_display_crf(self): + """Shows the choices for CRF post-processing depending on whether :py:attr:`self.use_crf` is checked""" + ui.toggle_visibility(self.use_crf, self.crf_widgets) + def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection""" ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) @@ -424,6 +437,8 @@ def _build(self): self.thresholding_slider.container, # thresholding self.use_instance_choice, self.instance_widgets, + self.use_crf, + self.crf_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -435,6 +450,7 @@ def _build(self): self.anisotropy_wdgt.container.setVisible(False) self.thresholding_slider.container.setVisible(False) self.instance_widgets.setVisible(False) + self.crf_widgets.setVisible(False) self.save_stats_to_csv_box.setVisible(False) post_proc_group.setLayout(post_proc_layout) @@ -588,6 +604,8 @@ def start(self): compute_stats=self.save_stats_to_csv_box.isChecked(), post_process_config=self.post_process_config, sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), ) ##################### ##################### @@ -737,7 +755,10 @@ def on_yield(self, result: InferenceResult): opacity=0.8, ) - if result.instance_labels is not None: + if ( + len(result.instance_labels) > 0 + and self.worker_config.post_process_config.instance.enabled + ): for i, labels in enumerate(result.instance_labels): # labels = result.instance_labels method_name = ( @@ -779,5 +800,12 @@ def on_yield(self, result: InferenceResult): # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) + + if result.crf_results is not None: + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 462ee450..868dd279 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -13,6 +13,7 @@ ToInstanceUtils, ToSemanticUtils, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -22,6 +23,7 @@ "Convert to instance labels": ToInstanceUtils, "Convert to semantic labels": ToSemanticUtils, "Threshold": ThresholdUtils, + "CRF": CRFWidget, } @@ -30,7 +32,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() self._viewer = viewer - attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh"] + attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh", "crf"] self._create_utils_widgets(attr_names) # self.crop = Cropping(self._viewer) @@ -54,8 +56,15 @@ def __init__(self, viewer: "napari.viewer.Viewer"): def _build(self): layout = QVBoxLayout() ui.add_widgets(layout, self.utils_widgets) - layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) - layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) + ui.GroupedWidget.create_single_widget_group( + "Utilities", + widget=self.utils_choice, + layout=layout, + alignment=ui.BOTT_AL, + ) + + # layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) + # layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 1f3fd4c1..7250fe78 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -139,6 +139,20 @@ class PostProcessConfig: instance: InstanceSegConfig = InstanceSegConfig() +@dataclass +class CRFConfig: + """ + Class to record params for CRF + """ + + sa: float = 10 + sb: float = 5 + sg: float = 1 + w1: float = 10 + w2: float = 5 + n_iters: int = 5 + + ################ # Inference configs @@ -198,6 +212,8 @@ class InferenceWorkerConfig: compute_stats: bool = False post_process_config: PostProcessConfig = PostProcessConfig() sliding_window_config: SlidingWindowConfig = SlidingWindowConfig() + use_crf: bool = False + crf_config: CRFConfig = CRFConfig() images_filepaths: str = None layer: napari.layers.Layer = None diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 4b560a6d..209b093e 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -58,6 +58,8 @@ """Alias for Qt.AlignmentFlag.AlignAbsolute, to use in addWidget""" BOTT_AL = Qt.AlignmentFlag.AlignBottom """Alias for Qt.AlignmentFlag.AlignBottom, to use in addWidget""" +TOP_AL = Qt.AlignmentFlag.AlignTop +"""Alias for Qt.AlignmentFlag.AlignTop, to use in addWidget""" ############### # colors dark_red = "#72071d" # crimson red @@ -66,6 +68,9 @@ napari_param_grey = "#414851" # napari parameters menu color (lighter gray) napari_param_darkgrey = "#202228" # napari default LineEdit color ############### +# dimensions for utils ScrollArea +UTILS_MAX_WIDTH = 300 +UTILS_MAX_HEIGHT = 500 logger = utils.LOGGER @@ -792,7 +797,7 @@ def layer_data(self): logger.warning("Please select a valid layer !") return - return self._viewer.layers[self.layer_name()].data + return self.layer().data class FilePathWidget(QWidget): # TODO include load as folder @@ -1278,12 +1283,20 @@ def set_layout(self): @classmethod def create_single_widget_group( - cls, title, widget, layout, l=7, t=20, r=7, b=11 + cls, + title, + widget, + layout, + l=7, + t=20, + r=7, + b=11, + alignment=LEFT_AL, ): group = cls(title, l, t, r, b) group.layout.addWidget(widget) group.setLayout(group.layout) - layout.addWidget(group) + layout.addWidget(group, alignment=alignment) def add_widgets(layout, widgets, alignment=LEFT_AL): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index c6c05ad1..171c20f0 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,11 +2,12 @@ from datetime import datetime from pathlib import Path +import napari import numpy as np from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian -from tifffile import imread as tfl_imread +from tifffile import imread, imwrite LOGGER = logging.getLogger(__name__) ############### @@ -21,6 +22,76 @@ """ +#################### +# viewer utils +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False, parents=True) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + LOGGER.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + LOGGER.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + LOGGER.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + LOGGER.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + LOGGER.warning( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + +#################### + + class Singleton(type): """ Singleton class that can only be instantiated once at a time, @@ -44,7 +115,7 @@ def __call__(cls, *args, **kwargs): # if filename == "tif": # return True # def read(self, data, **kwargs): -# return tfl_imread(data) +# return imread(data) # # def get_data(self, data): # return data, {} @@ -234,7 +305,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # logger.warning( + # LOGGER.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -468,9 +539,7 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): ) # images_original = dask_imread(filename_pattern_original) else: - images_original = tfl_imread( - filename_pattern_original - ) # tifffile imread + images_original = imread(filename_pattern_original) # tifffile imread return images_original From 8d301f0771959da74e65c02298db020bae89f228 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:37:33 +0200 Subject: [PATCH 089/577] Fix erroneous test comment, added toggle for crf - Warn if crf not installed - Fix test --- napari_cellseg3d/_tests/test_utils.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 22 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index f2a9d32c..0b28183d 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -88,7 +88,7 @@ def test_get_padding_dim(): # "The padding value is currently 2048." # ) # - # pad = utils.get_padding_dim(size) + pad = utils.get_padding_dim(size) # # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 3dbd47bb..cbdacf3a 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -7,7 +7,11 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_models.crf import ( + CRF_INSTALLED, + CRFWorker, + crf_with_config, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.utils import LOGGER as logger @@ -43,6 +47,17 @@ def __init__(self, parent=None): self._set_tooltips() def _build(self): + if not CRF_INSTALLED: + ui.add_widgets( + self.layout, + [ + ui.make_label( + "ERROR: CRF not installed.\nPlease refer to the documentation to install it." + ), + ], + ) + self.set_layout() + return ui.add_widgets( self.layout, [ @@ -113,7 +128,10 @@ def __init__(self, viewer, parent=None): napari.layers.Image ) # to load all crf-compatible inputs, not int only self.image_layer_loader.setVisible(True) - self.start_button.setVisible(True) + if CRF_INSTALLED: + self.start_button.setVisible(True) + else: + self.start_button.setVisible(False) self.result_layer = None self.result_name = None From 398b72e4758e0870b6915c6233cb83474d97a2bf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:56:08 +0200 Subject: [PATCH 090/577] Specify missing test deps --- pyproject.toml | 3 ++- tox.ini | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 21c40a6f..318f3331 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ -"git+https://github.com/lucasb-eyer/pydensecrf.git", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -82,4 +82,5 @@ test = [ "coverage", "tox", "twine", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 030d7437..f19da4a8 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + "git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] ; opencv-python From 418a8f9f7538cc0c2cba6ddb4a3b38d1f13e630d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:02:31 +0200 Subject: [PATCH 091/577] Trying to fix deps on Git --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 318f3331..1c0a7ba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] diff --git a/tox.ini b/tox.ini index f19da4a8..4598508a 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - "git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" ; pyopencl[pocl] ; opencv-python From 0090f09699f2d286a405c434b976e7480f790761 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:04:33 +0200 Subject: [PATCH 092/577] Removed master link to pydensecrf --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c0a7ba8..36ac3c3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 4598508a..b2d014d2 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] ; opencv-python From 1187507b7353f6cb56d688bdca638ee5ad842324 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:07:23 +0200 Subject: [PATCH 093/577] Use commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 36ac3c3d..6f02edda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] diff --git a/tox.ini b/tox.ini index b2d014d2..719ab398 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" ; pyopencl[pocl] ; opencv-python From 935eb92e464532559f0028ebb35655b2c7cf1fd8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:09:27 +0200 Subject: [PATCH 094/577] Removed commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f02edda..d24f4b65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] diff --git a/tox.ini b/tox.ini index 719ab398..192cf82b 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" ; pyopencl[pocl] ; opencv-python From 008bca1bdaf5a967f34b92f914aa80e6d13672ba Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:11:27 +0200 Subject: [PATCH 095/577] Removed master --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d24f4b65..7bb19dcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] diff --git a/tox.ini b/tox.ini index 192cf82b..8f04f6d4 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" ; pyopencl[pocl] ; opencv-python From 5da7d482bd8aad0cb5e9c86b0cb608e86a8de71e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:17:16 +0200 Subject: [PATCH 096/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 8f04f6d4..abdbfd63 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" + pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python From c1fa2a4af5a971b1757eb5a7438fa8bd75989968 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 09:06:23 +0200 Subject: [PATCH 097/577] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7bb19dcb..547047ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] dev = [ "isort", @@ -82,5 +82,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] From e22b9152788ce5ed826f639d111dab3f05a51142 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 098/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- docs/res/code/plugin_convert.rst | 5 - .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 276 +++++++++++++++++- napari_cellseg3d/code_models/model_workers.py | 32 +- .../code_plugins/plugin_convert.py | 155 +--------- .../code_plugins/plugin_model_inference.py | 25 +- napari_cellseg3d/config.py | 38 +-- napari_cellseg3d/interface.py | 21 +- requirements.txt | 1 + 9 files changed, 306 insertions(+), 248 deletions(-) diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index c7dc7df9..03944510 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -19,11 +19,6 @@ ToSemanticUtils .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToSemanticUtils :members: __init__ -InstanceWidgets -********************************** -.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::InstanceWidgets - :members: __init__, run_method - ToInstanceUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToInstanceUtils diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..584ffd3b 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,6 +8,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e4bec4ea..2cb7728b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -6,20 +6,65 @@ import numpy as np -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area +import pyclesperanto_prototype as cle +from qtpy.QtWidgets import QWidget + from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed + +from skimage.filters import thresholding +from skimage.transform import resize + +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread +from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -50,18 +95,43 @@ def get_dict(self): def threshold(volume, thresh): + """Remove all values smaller than the specified threshold in the volume""" im = np.squeeze(volume) binary = im > thresh return np.where(binary, im, np.zeros_like(im)) +def voronoi_otsu( + volume: np.ndarray, + spot_sigma: float, + outline_sigma: float, + remove_small_size: float, +): + """ + Voronoi-Otsu labeling from pyclesperanto. + BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase + https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: + volume (np.ndarray): volume to segment + spot_sigma (float): parameter determining how close detected objects can be + outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels + + Returns: + Instance segmentation labels from Voronoi-Otsu method + """ + semantic = np.squeeze(volume) + instance = cle.voronoi_otsu_labeling( + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + ) + # instance = remove_small_objects(instance, remove_small_size) + return instance + + def binary_connected( volume, thres=0.5, thres_small=3, - # scale_factors=(1.0, 1.0, 1.0), - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via connected-component labeling. @@ -70,7 +140,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) """ semantic = np.squeeze(volume) foreground = semantic > thres # int(255 * thres) @@ -97,12 +166,9 @@ def binary_connected( def binary_watershed( volume, thres_objects=0.3, - thres_small=10, thres_seeding=0.9, - # scale_factors=(1.0, 1.0, 1.0), + thres_small=10, rem_seed_thres=3, - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via watershed segmentation algorithm. @@ -113,10 +179,9 @@ def binary_watershed( Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. - thres_seeding (float): threshold for seeding. Default: 0.98 thres_objects (float): threshold for foreground objects. Default: 0.3 + thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ semantic = np.squeeze(volume) @@ -193,7 +258,7 @@ def to_instance(image, is_file_path=False): result = binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # TODO add params + ) # FIXME add params from utils plugin return result @@ -283,3 +348,188 @@ def fill(lst, n=len(properties) - 1): ratio, fill([len(properties)]), ) + + +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Watershed", + function=binary_watershed, + num_sliders=2, + num_counters=2, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(50) + + self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].tooltips = "Probability threshold for seeding" + self.sliders[1].setValue(90) + + self.counters[0].label.setText("Small object removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(30) + + self.counters[1].label.setText("Small seed removal") + self.counters[1].tooltips = ( + "Volume/size threshold for small seeds removal." + "\nAll seeds with a volume/size below this value will be removed." + ) + self.counters[1].setValue(3) + + def run_method(self, image): + return self.function( + image, + self.sliders[0].value(), + self.sliders[1].value(), + self.counters[0].value(), + self.counters[1].value(), + ) + + +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Connected Components", + function=binary_connected, + num_sliders=1, + num_counters=1, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(80) + + self.counters[0].label.setText("Small objects removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(3) + + def run_method(self, image): + return self.function( + image, self.sliders[0].value(), self.counters[0].value() + ) + + +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): + super().__init__( + name="Voronoi-Otsu", + function=voronoi_otsu, + num_sliders=0, + num_counters=3, + ) + self.counters[0].label.setText("Spot sigma") + self.counters[ + 0 + ].tooltips = "Determines how close detected objects can be" + self.counters[0].setMaximum(100) + self.counters[0].setValue(2) + + self.counters[1].label.setText("Outline sigma") + self.counters[ + 1 + ].tooltips = "Determines the smoothness of the segmentation" + self.counters[1].setMaximum(100) + self.counters[1].setValue(2) + + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + + def run_method(self, image): + return self.function( + image, + self.counters[0].value(), + self.counters[1].value(), + self.counters[2].value(), + ) + + +class InstanceWidgets(QWidget): + """ + Base widget with several sliders, for use in instance segmentation parameters + """ + + def __init__(self, parent=None): + """ + Creates an InstanceWidgets widget + + Args: + parent: parent widget + """ + super().__init__(parent) + + self.method_choice = ui.DropdownMenu( + INSTANCE_SEGMENTATION_METHOD_LIST.keys() + ) + self.methods = [] + self.instance_widgets = {} + + self.method_choice.currentTextChanged.connect(self._set_visibility) + self._build() + + def _build(self): + + group = ui.GroupedWidget("Instance segmentation") + group.layout.addWidget(self.method_choice) + + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + + self.setLayout(group.layout) + self._set_visibility() + + def _set_visibility(self): + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: + widget.set_visibility(False) + + def run_method(self, volume): + """ + Calls instance function with chosen parameters + Args: + volume: image data to run method on + + Returns: processed image from self._method + """ + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + return method.run_method(volume) + + +INSTANCE_SEGMENTATION_METHOD_LIST = { + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, +} diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index c5675a11..0fede76e 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -50,12 +50,6 @@ from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, -) from napari_cellseg3d.code_models.model_instance_seg import ImageStats from napari_cellseg3d.code_models.model_instance_seg import volume_stats @@ -604,30 +598,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - threshold = ( - self.config.post_process_config.instance.threshold.threshold_value - ) - size_small = ( - self.config.post_process_config.instance.small_object_removal_threshold.threshold_value - ) - method_name = self.config.post_process_config.instance.method - - if method_name == "Watershed": # FIXME use dict in config instead - - def method(image): - return binary_watershed(image, threshold, size_small) - - elif method_name == "Connected components": - - def method(image): - return binary_connected(image, threshold, size_small) - - else: - raise NotImplementedError( - "Selected instance segmentation method is not defined" - ) - - instance_labels = method(to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 5560b7b9..f461b46f 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -16,6 +16,7 @@ ) from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -358,160 +359,6 @@ def _start(self): self.images_filepaths, ) - -class InstanceWidgets(QWidget): - """ - Base widget with several sliders, for use in instance segmentation parameters - """ - - def __init__(self, parent=None): - """ - Creates an InstanceWidgets widget - - Args: - parent: parent widget - """ - super().__init__(parent) - - self.method_choice = ui.DropdownMenu( - config.INSTANCE_SEGMENTATION_METHOD_LIST.keys() - ) - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ] - - self.method_choice.currentTextChanged.connect(self._show_connected) - self.method_choice.currentTextChanged.connect(self._show_watershed) - - self.threshold_slider1 = ui.Slider( - lower=0, - upper=100, - default=50, - divide_factor=100.0, - step=5, - text_label="Probability threshold :", - ) - """Base prob. threshold""" - self.threshold_slider2 = ui.Slider( - lower=0, - upper=100, - default=90, - divide_factor=100.0, - step=5, - text_label="Probability threshold (seeding) :", - ) - """Second prob. thresh. (seeding)""" - - self.counter1 = ui.IntIncrementCounter( - upper=100, - default=10, - step=5, - label="Small object removal (pxs) :", - ) - """Small obj. rem.""" - - self.counter2 = ui.IntIncrementCounter( - upper=100, - default=3, - step=5, - label="Small seed removal (pxs) :", - ) - """Small seed rem.""" - - self._build() - - def run_method(self, volume): - """ - Calls instance function with chosen parameters - Args: - volume: image data to run method on - - Returns: processed image from self._method - """ - return self._method( - volume, - self.threshold_slider1.slider_value, - self.counter1.value(), - self.threshold_slider2.slider_value, - self.counter2.value(), - ) - - def _build(self): - group = ui.GroupedWidget("Instance segmentation") - - ui.add_widgets( - group.layout, - [ - self.method_choice, - self.threshold_slider1.container, - self.threshold_slider2.container, - self.counter1.label, - self.counter1, - self.counter2.label, - self.counter2, - ], - ) - - self.setLayout(group.layout) - self._set_tooltips() - - def _set_tooltips(self): - self.method_choice.setToolTip( - "Choose which method to use for instance segmentation" - "\nConnected components : all separated objects will be assigned an unique ID. " - "Robust but will not work correctly with adjacent/touching objects\n" - "Watershed : assigns objects ID based on the probability gradient surrounding an object. " - "Requires the model to surround objects in a gradient;" - " can possibly correctly separate unique but touching/adjacent objects." - ) - self.threshold_slider1.tooltips = ( - "All objects below this probability will be ignored (set to 0)" - ) - self.counter1.setToolTip( - "Will remove all objects smaller (in volume) than the specified number of pixels" - ) - self.threshold_slider2.tooltips = ( - "All seeds below this probability will be ignored (set to 0)" - ) - self.counter2.setToolTip( - "Will remove all seeds smaller (in volume) than the specified number of pixels" - ) - - def _show_watershed(self): - name = "Watershed" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2() - self._show_counter1() - self._show_counter2() - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_connected(self): - name = "Connected components" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2(False) - self._show_counter1() - self._show_counter2(False) - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_slider1(self, is_visible: bool = True): - self.threshold_slider1.container.setVisible(is_visible) - - def _show_slider2(self, is_visible: bool = True): - self.threshold_slider2.container.setVisible(is_visible) - - def _show_counter1(self, is_visible: bool = True): - self.counter1.setVisible(is_visible) - self.counter1.label.setVisible(is_visible) - - def _show_counter2(self, is_visible: bool = True): - self.counter2.setVisible(is_visible) - self.counter2.label.setVisible(is_visible) - - class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 0dca3ec8..4a7ab671 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -12,7 +12,11 @@ from napari_cellseg3d.code_models.model_framework import ModelFramework from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_plugins.plugin_convert import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -77,9 +81,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): config.InferenceWorkerConfig() ) """InferenceWorkerConfig class from config.py""" - self.instance_config: config.InstanceSegConfig = ( - config.InstanceSegConfig() - ) + self.instance_config: InstanceMethod """InstanceSegConfig class from config.py""" self.post_process_config: config.PostProcessConfig = ( config.PostProcessConfig() @@ -551,18 +553,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - instance_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.threshold_slider1.slider_value - ) - instance_small_object_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.counter1.value() - ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.method_choice.currentText(), - threshold=instance_thresh_config, - small_object_removal_threshold=instance_small_object_thresh_config, - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 57c65bac..74cbf81d 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,22 +8,20 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, -) -from napari_cellseg3d.code_models.models import ( - model_SegResNet as SegResNet, -) -from napari_cellseg3d.code_models.models import ( - model_SwinUNetR as SwinUNetR, -) + +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet +from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR from napari_cellseg3d.code_models.models import ( model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -40,10 +38,6 @@ # "test" : DO NOT USE, reserved for testing } -INSTANCE_SEGMENTATION_METHOD_LIST = { - "Watershed": binary_watershed, - "Connected components": binary_connected, -} WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") @@ -127,21 +121,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: str = None - threshold: Thresholding = Thresholding(enabled=False, threshold_value=0.85) - small_object_removal_threshold: Thresholding = Thresholding( - enabled=True, threshold_value=20 - ) - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d23199ee..d2f8d787 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -12,6 +12,7 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt +# from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices @@ -499,9 +500,12 @@ def __init__( self._build_container() - def _build_container(self): - self.container.layout + def set_visibility(self, visible: bool): + self.container.setVisible(visible) + self.setVisible(visible) + self.text_label.setVisible(visible) + def _build_container(self): if self.text_label is not None: add_widgets( self.container.layout, @@ -1021,7 +1025,7 @@ class DoubleIncrementCounter(QDoubleSpinBox): def __init__( self, lower: Optional[float] = 0.0, - upper: Optional[float] = 10.0, + upper: Optional[float] = 1000.0, default: Optional[float] = 0.0, step: Optional[float] = 1.0, parent: Optional[QWidget] = None, @@ -1045,6 +1049,13 @@ def __init__( if label is not None: self.label = make_label(name=label) + self.valueChanged.connect(self._update_step) + + def _update_step(self): + if self.value() < 0.9: + self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): @@ -1081,6 +1092,10 @@ def make_n( cls, n, lower, upper, default, step, parent, fixed ) + def set_visibility(self, visible: bool): + self.setVisible(visible) + self.label.setVisible(visible) + class IntIncrementCounter(QSpinBox): """Class implementing a number counter with increments (spin box) for int.""" diff --git a/requirements.txt b/requirements.txt index f97de33c..7d2f12f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 7a2e4916c94e67ca180af3a561134281b8b80303 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 099/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2cb7728b..81b1744b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -105,7 +105,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -115,11 +115,12 @@ def voronoi_otsu( volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels + Returns: Instance segmentation labels from Voronoi-Otsu method """ + # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma @@ -427,7 +428,7 @@ def __init__(self): name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, ) self.counters[0].label.setText("Spot sigma") self.counters[ @@ -443,18 +444,19 @@ def __init__(self): self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): return self.function( image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) From 9b57ccd43c67af3636db94a6b3589a445f78cffb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 100/577] Added new docs for instance seg --- docs/res/code/model_instance_seg.rst | 23 +++++++++++++++++++ .../code_models/model_instance_seg.py | 22 ++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst index e4146ec1..3b323173 100644 --- a/docs/res/code/model_instance_seg.rst +++ b/docs/res/code/model_instance_seg.rst @@ -1,6 +1,29 @@ model_instance_seg.py =========================================== +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu + :members: __init__ + Functions ------------- diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 81b1744b..7fd33317 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -37,6 +37,14 @@ def __init__( num_sliders: int, num_counters: int, ): + """ + Methods for instance segmentation + Args: + name: Name of the instance segmentation method (for UI) + function: Function to use for instance segmentation + num_sliders: Number of Slider UI elements needed to set the parameters of the function + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + """ self.name = name self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] @@ -176,7 +184,7 @@ def binary_watershed( Note: This function uses the `skimage.segmentation.watershed `_ - function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays. + function that converts the input image into ``np.float64`` data type for processing. Therefore, please make sure enough memory is allocated when handling large arrays. Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. @@ -352,6 +360,8 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod, metaclass=Singleton): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + def __init__(self): super().__init__( name="Watershed", @@ -395,6 +405,8 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod, metaclass=Singleton): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + def __init__(self): super().__init__( name="Connected Components", @@ -423,6 +435,8 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + def __init__(self): super().__init__( name="Voronoi-Otsu", @@ -430,14 +444,14 @@ def __init__(self): num_sliders=0, num_counters=2, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -531,7 +545,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { + VoronoiOtsu().name: VoronoiOtsu, Watershed().name: Watershed, ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, } From 2a645b7c9f98f2c6b6b646b4c578b14c4f9a3fb8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 101/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 5 +++-- docs/res/welcome.rst | 15 +++++++++------ napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 011f072c..d6037a0d 100644 --- a/README.md +++ b/README.md @@ -143,8 +143,9 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). - +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). +Please refer to the documentation for full acknowledgements. ## Plugin base This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 6832e71e..d2f2c0f0 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -90,20 +90,23 @@ We also provide a model that was trained in-house on mesoSPIM nuclei data in col This plugin mainly uses the following libraries and software: -* `napari website`_ +* `napari`_ -* `PyTorch website`_ +* `PyTorch`_ -* `MONAI project website`_ (various models used here are credited `on their website`_) +* `MONAI project`_ (various models used here are credited `on their website`_) + +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ .. _Wyss Center: https://wysscenter.ch/ .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP -.. _napari website: https://napari.org/ -.. _PyTorch website: https://pytorch.org/ -.. _MONAI project website: https://monai.io/ +.. _napari: https://napari.org/ +.. _PyTorch: https://pytorch.org/ +.. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets +.. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype .. rubric:: References diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d2f8d787..136da3e1 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1053,9 +1053,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From c256337443aff9124660b7228f094b694c37f4ff Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 102/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7d2f12f7..93da070f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From fa0bef33e8e197b8a1dd720b15d2356927eae16f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 103/577] isort --- napari_cellseg3d/code_models/model_instance_seg.py | 12 ++---------- napari_cellseg3d/code_plugins/plugin_convert.py | 9 ++------- .../code_plugins/plugin_model_inference.py | 8 ++++---- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 3 ++- 5 files changed, 15 insertions(+), 28 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 7fd33317..7a5f097b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,30 +1,22 @@ from __future__ import division from __future__ import print_function - from dataclasses import dataclass from typing import List - import numpy as np - import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget - from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -from skimage.filters import thresholding -from skimage.transform import resize - +from tifffile import imread # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes -from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f461b46f..6432d761 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,22 +1,17 @@ import warnings from pathlib import Path - import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QWidget from tifffile import imread from tifffile import imwrite import napari_cellseg3d.interface as ui -from napari_cellseg3d import config from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - clear_small_objects, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 4a7ab671..2420829e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,13 +10,13 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( INSTANCE_SEGMENTATION_METHOD_LIST, ) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 74cbf81d..e665d28c 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,6 +8,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -16,12 +21,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 136da3e1..a854905b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -12,7 +12,8 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt -# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From 4b4ae5228772531314753a2b48f2d5f2d43eedb6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 104/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 napari_cellseg3d/_tests/pytest.ini diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini new file mode 100644 index 00000000..814cca2e --- /dev/null +++ b/napari_cellseg3d/_tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +qt_api=pyqt5 \ No newline at end of file From 81a66e4ed296648e2877095352e23f9cca9ead8d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 105/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/model_instance_seg.py | 84 +++++++++++-------- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 2 +- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 7a5f097b..57065971 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,11 +15,16 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area +# list of methods : +WATERSHED = "Watershed" +CONNECTED_COMP = "Connected Components" +VORONOI_OTSU = "Voronoi-Otsu" + class InstanceMethod: def __init__( @@ -28,6 +33,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, + widget_parent: QWidget = None ): """ Methods for instance segmentation @@ -36,6 +42,7 @@ def __init__( function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets """ self.name = name self.function = function @@ -47,7 +54,7 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -57,7 +64,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label=""), + ui.DoubleIncrementCounter(label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -351,15 +358,16 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -396,15 +404,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -426,15 +435,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -479,7 +489,6 @@ def __init__(self, parent=None): parent: parent widget """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) @@ -490,37 +499,38 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -537,7 +547,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6432d761..7a59dcf0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -376,7 +376,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.data_panel = self._build_io_panel() self.label_layer_loader.set_layer_type(napari.layers.Layer) - self.instance_widgets = InstanceWidgets() + self.instance_widgets = InstanceWidgets(parent=self) self.start_btn = ui.Button("Start", self._start) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 2420829e..5ad8fc3e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -191,7 +191,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## ################## # instance segmentation widgets - self.instance_widgets = InstanceWidgets(self) + self.instance_widgets = InstanceWidgets(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", func=self._toggle_display_instance From 2e5c6e71f6e9c09a4f6eee3111ff7a2419c4c849 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 106/577] Fix inference --- .../code_models/model_instance_seg.py | 5 +- napari_cellseg3d/code_models/model_workers.py | 12 ++--- .../code_plugins/plugin_model_inference.py | 13 ++--- napari_cellseg3d/config.py | 6 ++- notebooks/assess_instance.ipynb | 50 +++++++++++++++++++ 5 files changed, 71 insertions(+), 15 deletions(-) create mode 100644 notebooks/assess_instance.ipynb diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 57065971..667b8bc3 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -492,8 +492,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -505,6 +507,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 0fede76e..47f63fec 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -309,9 +309,7 @@ def log_parameters(self): instance_config = config.post_process_config.instance if instance_config.enabled: self.log( - f"Instance segmentation enabled, method : {instance_config.method}\n" - f"Probability threshold is {instance_config.threshold.threshold_value:.2f}\n" - f"Objects smaller than {instance_config.small_object_removal_threshold.threshold_value} pixels will be removed\n" + f"Instance segmentation enabled, method : {instance_config.method.name}\n" ) self.log("-" * 20) @@ -383,7 +381,7 @@ def load_folder(self): return inference_loader def load_layer(self): - self.log("Loading layer\n") + self.log("\nLoading layer\n") data = np.squeeze(self.config.layer) volume = np.array(data, dtype=np.int16) @@ -545,7 +543,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -598,8 +596,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 5ad8fc3e..a6a90eb4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -553,9 +553,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -723,13 +724,13 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method = self.worker_config.post_process_config.instance.method + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 ) # remove background - name = f"({number_cells} objects)_{method}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" viewer.add_labels(labels, name=name) @@ -743,7 +744,7 @@ def on_yield(self, result: InferenceResult): f"Number of instances : {stats.number_objects}" ) - csv_name = f"/{method}_seg_results_{image_id}_{utils.get_date_time()}.csv" + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" stats_df.to_csv( self.worker_config.results_path + csv_name, index=False, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index e665d28c..107af8e6 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -119,12 +119,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb new file mode 100644 index 00000000..40412282 --- /dev/null +++ b/notebooks/assess_instance.ipynb @@ -0,0 +1,50 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tifffile import imread\n", + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From e5d77b4786e9b72472516434981c6f2790c6da83 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 107/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 9 +- napari_cellseg3d/code_plugins/plugin_crop.py | 4 +- .../code_plugins/plugin_utilities.py | 7 +- .../dev_scripts/artefact_labeling.py | 421 ++++++++++++++++++ .../dev_scripts/correct_labels.py | 320 +++++++++++++ .../dev_scripts/evaluate_labels.py | 276 ++++++++++++ notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 7 files changed, 1420 insertions(+), 18 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/artefact_labeling.py create mode 100644 napari_cellseg3d/dev_scripts/correct_labels.py create mode 100644 napari_cellseg3d/dev_scripts/evaluate_labels.py diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 667b8bc3..a8bb240b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -37,12 +37,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -118,14 +120,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method + """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) @@ -191,6 +194,7 @@ def binary_watershed( thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 rem_seed_thres (int): threshold for small seeds removal. Default : 3 + """ semantic = np.squeeze(volume) seed_map = semantic > thres_seeding @@ -487,6 +491,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -538,10 +543,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 97485bf4..406ae7e7 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -177,8 +177,8 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 400]) - self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding) + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() # def _check_results_path(self, folder): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 9e66213f..1f0d598b 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -60,10 +60,10 @@ def _build(self): layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) - layout.setSizeConstraint(QLayout.SetFixedSize) + # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) - self.setMinimumHeight(1000) - self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + # self.setMinimumHeight(2000) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._update_visibility() def _create_utils_widgets(self, names): @@ -79,7 +79,6 @@ def _create_utils_widgets(self, names): raise RuntimeError( "One or several utility widgets are missing/erroneous" ) - # TODO how to auto-update list based on UTILITIES_WIDGETS ? def _update_visibility(self): widget_class = UTILITIES_WIDGETS[self.utils_choice.currentText()] diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py new file mode 100644 index 00000000..875ca9b6 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -0,0 +1,421 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path +import scipy.ndimage as ndimage +import os +import napari +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu + +""" +New code by Yves Paychere +Creates labels of artifacts in an image based on existing labels of neurons +""" + + +def map_labels(labels, artefacts): + """Map the artefacts labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + artefacts : ndarray + Label image with artefacts labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the artefact and the label value of the neurone associated or the neurons associated + new_labels: list + The labels of the artefacts that are not labelled in the neurons + """ + map_labels_existing = [] + new_labels = [] + + for i in np.unique(artefacts): + if i == 0: + continue + indexes = labels[artefacts == i] + # find the most common label in the indexes + unique, counts = np.unique(indexes, return_counts=True) + unique = np.flip(unique[np.argsort(counts)]) + counts = np.flip(counts[np.argsort(counts)]) + if unique[0] != 0: + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + elif ( + counts[0] < np.sum(counts) * 2 / 3.0 + ): # the artefact is connected to multiple neurons + total = 0 + ii = 1 + while total < np.size(indexes) / 3.0: + total = np.sum(counts[1 : ii + 1]) + ii += 1 + map_labels_existing.append(np.append([i], unique[1 : ii + 1])) + else: + new_labels.append(i) + + return map_labels_existing, new_labels + + +def make_labels( + path_image, + path_labels_out, + threshold_factor=1, + threshold_size=30, + label_value=1, + do_multi_label=True, + use_watershed=True, + augment_contrast_factor=2, +): + """Detect nucleus. using a binary watershed algorithm and otsu thresholding. + Parameters + ---------- + path_image : str + Path to image. + path_labels_out : str + Path of the output labelled image. + threshold_size : int, optional + Threshold for nucleus size, if the nucleus is smaller than this value it will be removed. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different nucleus will be labelled as a different value. + use_watershed : bool, optional + If True, use watershed algorithm to detect nucleus. + augment_contrast_factor : int, optional + Factor to augment the contrast of the image. + Returns + ------- + ndarray + Label image with nucleus labelled with 1 value per nucleus. + """ + + image = imread(path_image) + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + threshold_brightness = threshold_otsu(image) * threshold_factor + image_contrasted = np.where(image > threshold_brightness, image, 0) + + if use_watershed: + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = image_contrasted * augment_contrast_factor + image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) + labels = binary_watershed(image_contrasted, thres_small=threshold_size) + else: + labels = ndimage.label(image_contrasted)[0] + + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + + if not do_multi_label: + labels = np.where(labels > 0, label_value, 0) + + imwrite(path_labels_out, labels.astype(np.uint16)) + imwrite( + path_labels_out.replace(".tif", "_contrast.tif"), + image_contrasted.astype(np.float32), + ) + + +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): + """Select image by labels. + Parameters + ---------- + path_image : str + Path to image. + path_labels : str + Path to labels. + path_image_out : str + Path of the output image. + label_values : list + List of label values to select. + """ + image = imread(path_image) + labels = imread(path_labels) + image = np.where(np.isin(labels, label_values), image, 0) + imwrite(path_image_out, image.astype(np.float32)) + + +# select the smalles cube that contains all the none zero pixel of an 3d image +def get_bounding_box(img): + height = np.any(img, axis=(0, 1)) + rows = np.any(img, axis=(0, 2)) + cols = np.any(img, axis=(1, 2)) + + xmin, xmax = np.where(cols)[0][[0, -1]] + ymin, ymax = np.where(rows)[0][[0, -1]] + zmin, zmax = np.where(height)[0][[0, -1]] + return xmin, xmax, ymin, ymax, zmin, zmax + + +# crop the image +def crop_image(img): + xmin, xmax, ymin, ymax, zmin, zmax = get_bounding_box(img) + return img[xmin:xmax, ymin:ymax, zmin:zmax] + + +def crop_image_path(path_image, path_image_out): + """Crop image. + Parameters + ---------- + path_image : str + Path to image. + path_image_out : str + Path of the output image. + """ + image = imread(path_image) + image = crop_image(image) + imwrite(path_image_out, image.astype(np.float32)) + + +def make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, + label_value=2, + do_multi_label=False, + remove_true_labels=True, +): + """Detect pseudo nucleus. + Parameters + ---------- + image : ndarray + Image. + labels : ndarray + Label image. + threshold_artefact_brightness_percent : int, optional + Threshold for artefact brightness. + threshold_artefact_size_percent : int, optional + Threshold for artefact size, if the artefcact is smaller than this percentage of the neurons it will be removed. + contrast_power : int, optional + Power for contrast enhancement. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different artefact will be labelled as a different value. + remove_true_labels : bool, optional + If True, the true labels will be removed from the artefacts. + Returns + ------- + ndarray + Label image with pseudo nucleus labelled with 1 value per artefact. + """ + + neurons = np.array(labels > 0) + non_neurons = np.array(labels == 0) + + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + # calculate the percentile of the intensity of all the pixels that are labeled as neurons + # check if the neurons are not empty + if np.sum(neurons) > 0: + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + else: + # take the percentile of the non neurons if the neurons are empty + threshold = np.percentile(image[non_neurons], 90) + + # modify the contrast of the image accoring to the threshold with a tanh function and map the values to [0,1] + + image_contrasted = np.tanh((image - threshold) * contrast_power) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) + + artefacts = binary_watershed( + image_contrasted, thres_seeding=0.95, thres_small=15, thres_objects=0.4 + ) + + if remove_true_labels: + # evaluate where the artefacts are connected to the neurons + # map the artefacts label to the neurons label + map_labels_existing, new_labels = map_labels(labels, artefacts) + + # remove the artefacts that are connected to the neurons + for i in map_labels_existing: + artefacts[artefacts == i[0]] = 0 + # remove all the pixels of the neurons from the artefacts + artefacts = np.where(labels > 0, 0, artefacts) + + # remove the artefacts that are too small + # calculate the percentile of the size of the neurons + if np.sum(neurons) > 0: + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + else: + # find the size of each connected component + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + # remove the smallest connected components + neurone_size_percentile = np.percentile(sizes, 95) + + # select the artefacts that are bigger than the percentile + + artefacts = select_artefacts_by_size( + artefacts, min_size=neurone_size_percentile, is_labeled=True + ) + + # relabel with the label value if the artefacts are not multi label + if not do_multi_label: + artefacts = np.where(artefacts > 0, label_value, artefacts) + + return artefacts + + +def select_artefacts_by_size(artefacts, min_size, is_labeled=False): + """Select artefacts by size. + Parameters + ---------- + artefacts : ndarray + Label image with artefacts labelled as 1. + min_size : int, optional + Minimum size of artefacts to keep + is_labeled : bool, optional + If True, the artefacts are already labelled. + Returns + ------- + ndarray + Label image with artefacts labelled and small artefacts removed. + """ + if not is_labeled: + # find all the connected components in the artefacts image + labels = ndimage.label(artefacts)[0] + else: + labels = artefacts + + # remove the small components + labels_i, counts = np.unique(labels, return_counts=True) + labels_i = labels_i[counts > min_size] + labels_i = labels_i[labels_i > 0] + artefacts = np.where(np.isin(labels, labels_i), labels, 0) + return artefacts + + +def create_artefact_labels( + image_path, + labels_path, + output_path, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. + Parameters + ---------- + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. + output_path : str + Path to save the output label image file. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + image = imread(image_path) + labels = imread(labels_path) + + artefacts = make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power=contrast_power, + label_value=2, + do_multi_label=False, + ) + + neurons_artefacts_labels = np.where(labels > 0, 1, artefacts) + imwrite(output_path, neurons_artefacts_labels) + + +def visualize_images(paths): + """Visualize images. + Parameters + ---------- + paths : list + List of paths to images to visualize. + """ + viewer = napari.Viewer(ndisplay=3) + for path in paths: + viewer.add_image(imread(path), name=os.path.basename(path)) + # wait for the user to close the viewer + napari.run() + + +def create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1 for all images in a folder. The images created are stored in a folder artefact_neurons. + Parameters + ---------- + path : str + Path to folder with images in folder volumes and labels in folder lab_sem. The images are expected to have the same alphabetical order in both folders. + do_visualize : bool, optional + If True, the images will be visualized. + threshold_artefact_brightness_percent : int, optional + The artefacts need to be as least as bright as this percentage of the neurone's pixels. + threshold_artefact_size : int, optional + The artefacts need to be at least as big as this percentage of the neurons. + contrast_power : int, optional + Power for contrast enhancement. + """ + # find all the images in the folder and create a list + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + # sort the list + path_labels.sort() + path_images.sort() + # create the output folder + os.makedirs(path + "/artefact_neurons", exist_ok=True) + # create the artefact labels + for i in range(len(path_images)): + print(path_labels[i]) + # consider that the images and the labels have names in the same alphabetical order + create_artefact_labels( + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power, + ) + if do_visualize: + visualize_images( + [ + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + ] + ) + + +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py new file mode 100644 index 00000000..f94327e2 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -0,0 +1,320 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings +from napari.qt.threading import thread_worker +from tqdm import tqdm +import threading +# import sys +# sys.path.append(str(Path(__file__) / "../../")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +""" +New code by Yves Paychère +Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold +""" + + +def relabel_non_unique_i(label, save_path, go_fast=False): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + label : np.array + the label image + save_path : str + the path to save the relabeld image + """ + value_label = 0 + new_labels = np.zeros_like(label) + map_labels_existing = [] + unique_label = np.unique(label) + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + i = unique_label[i_label] + if i == 0: + continue + if go_fast: + new_label, to_add = ndimage.label(label == i) + map_labels_existing.append( + [i, list(range(value_label + 1, value_label + to_add + 1))] + ) + + else: + # catch the warning of the watershed + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_label = binary_watershed(label == i) + unique = np.unique(new_label) + to_add = unique[-1] + map_labels_existing.append([i, unique[1:] + value_label]) + + new_label[new_label != 0] += value_label + new_labels += new_label + value_label += to_add + + imwrite(save_path, new_labels) + return map_labels_existing + + +def add_label(old_label, artefact, new_label_path, i_labels_to_add): + """add the label to the label image + Parameters + ---------- + old_label : np.array + the label image + artefact : np.array + the artefact image that contains some neurons + new_label_path : str + the path to save the new label image + """ + new_label = old_label.copy() + max_label = np.max(old_label) + for i, i_label in enumerate(i_labels_to_add): + new_label[artefact == i_label] = i + max_label + 1 + imwrite(new_label_path, new_label) + + +returns = [] + + +def ask_labels(unique_artefact): + global returns + returns = [] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + + if i_labels_to_add_tmp == [0]: + print("no label added") + returns = [[]] + print("close the napari window to continue") + return + + for i in i_labels_to_add_tmp: + if i == 0: + print("0 is not a valid label") + # delete the 0 + i_labels_to_add_tmp.remove(i) + # test if all index are negative + if all(i < 0 for i in i_labels_to_add_tmp): + print( + "all labels are negative-> will add all the labels except the one you gave" + ) + i_labels_to_add = list(unique_artefact) + for i in i_labels_to_add_tmp: + if np.abs(i) in i_labels_to_add: + i_labels_to_add.remove(np.abs(i)) + else: + print("the label", np.abs(i), "is not in the label image") + i_labels_to_add_tmp = i_labels_to_add + else: + # remove the negative index + for i in i_labels_to_add_tmp: + if i < 0: + i_labels_to_add_tmp.remove(i) + print( + "ignore the negative label", + i, + " since not all the labels are negative", + ) + if i not in unique_artefact: + print("the label", i, "is not in the label image") + i_labels_to_add_tmp.remove(i) + + returns = [i_labels_to_add_tmp] + print("close the napari window to continue") + + +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + image_path : str + the path to the image + label_path : str + the path to the label image + go_fast : bool, optional + if True, the relabeling will be faster but the labels can more frequently be merged, by default False + check_for_unicity : bool, optional + if True, the relabeling will check if the labels are unique, by default True + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + global returns + + label = imread(label_path) + initial_label_path = label_path + if check_for_unicity: + # check if the label are unique + new_label_path = label_path[:-4] + "_relabel_unique.tif" + map_labels_existing = relabel_non_unique_i( + label, new_label_path, go_fast=go_fast + ) + print( + "visualize the relabeld image in white the previous labels and in red the new labels" + ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + label_path = new_label_path + # detect artefact + print("detection of potential neurons (in progress)") + image = imread(image_path) + artefact = make_artefact_labels.make_artefact_labels( + image, + imread(label_path), + do_multi_label=True, + threshold_artefact_brightness_percent=30, + threshold_artefact_size_percent=0, + contrast_power=30, + ) + print("detection of potential neurons (done)") + # ask the user if the artefact are not neurons + i_labels_to_add = [] + loop = True + unique_artefact = list(np.unique(artefact)) + while loop: + # visualize the artefact and ask the user which label to add to the label image + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t.start() + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="potential neurons") + viewer.add_labels(imread(label_path), name="labels") + napari.run() + t.join() + i_labels_to_add_tmp = returns[0] + # check if the selected labels are neurones + for i in i_labels_to_add: + if i not in i_labels_to_add_tmp: + i_labels_to_add_tmp.append(i) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + print("these labels will be added") + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if revert != "y": + i_labels_to_add = i_labels_to_add_tmp + for i in i_labels_to_add: + if i in unique_artefact: + unique_artefact.remove(i) + loop = input("Do you want to add more labels? (y/n)") == "y" + # add the label to the label image + new_label_path = initial_label_path[:-4] + "_new_label.tif" + print("the new label will be saved in", new_label_path) + add_label(imread(label_path), artefact, new_label_path, i_labels_to_add) + # store the artefact remaining + new_artefact_path = initial_label_path[:-4] + "_artefact.tif" + artefact = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + imwrite(new_artefact_path, artefact) + + +def modify_viewer(old_label, new_label, args): + """modify the viewer to show the relabeling + Parameters + ---------- + old_label : napari.layers.Labels + the layer of the old label + new_label : napari.layers.Labels + the layer of the new label + args : list + the first element is the old label and the second element is the new label + """ + if args == "hide new label": + new_label.visible = False + elif args == "show new label": + new_label.visible = True + else: + old_label.selected_label = args[0] + if not np.isnan(args[1]): + new_label.selected_label = args[1] + + +@thread_worker +def to_show(map_labels_existing, delay=0.5): + """modify the viewer to show the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the of the map between the old label and the new label + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + time.sleep(2) + for i in map_labels_existing: + yield "hide new label" + if len(i[1]): + yield [i[0], i[1][0]] + else: + yield [i[0], np.nan] + time.sleep(delay) + yield "show new label" + for j in i[1]: + yield [i[0], j] + time.sleep(delay) + + +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): + """Builds a widget that can control a function in another thread.""" + + worker = to_show(map_labels_existing, delay) + worker.start() + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + + +def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): + """visualize the map of the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the relabeling + """ + label = imread(label_path) + relabel = imread(relabel_path) + + viewer = napari.Viewer(ndisplay=3) + + old_label = viewer.add_labels(label, num_colors=3) + new_label = viewer.add_labels(relabel, num_colors=3) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + + # viewer.dims.ndisplay = 3 + viewer.camera.angles = (180, 3, 50) + viewer.camera.zoom = 1 + + old_label.show_selected_label = True + new_label.show_selected_label = True + + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + napari.run() + + +def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + folder_path : str + the path to the folder containing the label images + end_of_new_name : str + thename to add at the end of the relabled image + """ + for file in Path.iterdir(folder_path): + if file.suffix == ".tif": + label = imread(str(Path(folder_path / file))) + relabel_non_unique_i( + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + ) + + +if __name__ == "__main__": + + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py new file mode 100644 index 00000000..857bcd19 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -0,0 +1,276 @@ +import numpy as np +import pandas as pd +from tqdm import tqdm +import napari + +from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels + + +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): + """Evaluate the model performance. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + do_print : bool + If True, print the results. + Returns + ------- + neuron_found : float + The number of neurons found by the model + neuron_fused: float + The number of neurons fused by the model + neuron_not_found: float + The number of neurons not found by the model + neuron_artefact: float + The number of artefact that the model wrongly labelled as neurons + mean_true_positive_ratio_model: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the model's label) + mean_ratio_pixel_found: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_ratio_pixel_found_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_true_positive_ratio_model_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels in any fused neurons of this model's label)/(total number of pixels of the model's label) + mean_ratio_false_pixel_artefact: float + The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) + """ + log.debug("Mapping labels...") + map_labels_existing, map_fused_neurons, new_labels = map_labels( + labels, model_labels + ) + + # calculate the number of neurons individually found + neurons_found = len(map_labels_existing) + # calculate the number of neurons fused + neurons_fused = len(map_fused_neurons) + # calculate the number of neurons not found + log.debug("Calculating the number of neurons not found...") + neurons_found_labels = np.unique( + [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] + ) + unique_labels = np.unique(labels) + neurons_not_found = len(unique_labels) - 1 - len(neurons_found_labels) + # artefacts found + artefacts_found = len(new_labels) + if len(map_labels_existing) > 0: + # calculate the mean true positive ratio of the model + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + # calculate the mean ratio of the neurons pixels correctly labelled + mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) + else: + mean_true_positive_ratio_model = np.nan + mean_ratio_pixel_found = np.nan + + if len(map_fused_neurons) > 0: + # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + # calculate the mean true positive ratio of the model for the fused neurons + mean_true_positive_ratio_model_fused = np.mean( + [i[3] for i in map_fused_neurons] + ) + else: + mean_ratio_pixel_found_fused = np.nan + mean_true_positive_ratio_model_fused = np.nan + + # calculate the mean false positive ratio of each artefact + if len(new_labels) > 0: + mean_ratio_false_pixel_artefact = np.mean([i[1] for i in new_labels]) + else: + mean_ratio_false_pixel_artefact = np.nan + + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( + "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, + ) + print( + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, + ) + print( + "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, + ) + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + ) + if visualize: + viewer = napari.Viewer() + viewer.add_labels(labels, name="ground truth") + viewer.add_labels(model_labels, name="model's labels") + found_model = np.where( + np.isin(model_labels, [i[0] for i in map_labels_existing]), + model_labels, + 0, + ) + viewer.add_labels(found_model, name="model's labels found") + found_label = np.where( + np.isin(labels, [i[1] for i in map_labels_existing]), labels, 0 + ) + viewer.add_labels(found_label, name="ground truth found") + neurones_not_found_labels = np.where( + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + ) + neurones_not_found_labels = neurones_not_found_labels[ + neurones_not_found_labels != 0 + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + viewer.add_labels(not_found, name="ground truth not found") + artefacts_found = np.where( + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + ) + viewer.add_labels(artefacts_found, name="model's labels artefacts") + fused_model = np.where( + np.isin(model_labels, [i[0] for i in map_fused_neurons]), + model_labels, + 0, + ) + viewer.add_labels(fused_model, name="model's labels fused") + fused_label = np.where( + np.isin(labels, [i[1] for i in map_fused_neurons]), labels, 0 + ) + viewer.add_labels(fused_label, name="ground truth fused") + napari.run() + + return ( + neurons_found, + neurons_fused, + neurons_not_found, + artefacts_found, + mean_true_positive_ratio_model, + mean_ratio_pixel_found, + mean_ratio_pixel_found_fused, + mean_true_positive_ratio_model_fused, + mean_ratio_false_pixel_artefact, + ) + + +def save_as_csv(results, path): + """ + Save the results as a csv file + + Parameters + ---------- + results: list + The results of the evaluation + path: str + The path to save the csv file + """ + print(np.array(results).shape) + df = pd.DataFrame( + [results], + columns=[ + "neurons_found", + "neurons_fused", + "neurons_not_found", + "artefacts_found", + "mean_true_positive_ratio_model", + "mean_ratio_pixel_found", + "mean_ratio_pixel_found_fused", + "mean_true_positive_ratio_model_fused", + "mean_ratio_false_pixel_artefact", + ], + ) + df.to_csv(path, index=False) + + +# if __name__ == "__main__": +# """ +# # Example of how to use the functions in this module. +# a = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# b = np.array([[5, 5, 0, 0], [5, 5, 2, 0], [0, 2, 2, 0], [0, 0, 2, 0]]) +# evaluate_model_performance(a, b) +# +# c = np.array([[2, 2, 0, 0], [2, 2, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +# +# d = np.array([[4, 0, 4, 0], [4, 4, 4, 0], [0, 4, 4, 0], [0, 0, 4, 0]]) +# +# evaluate_model_performance(c, d) +# +# from tifffile import imread +# labels=imread("dataset/visual_tif/labels/testing_im_new_label.tif") +# labels_model=imread("dataset/visual_tif/artefact_neurones/basic_model.tif") +# evaluate_model_performance(labels, labels_model,visualize=True) +# """ +# from tifffile import imread +# +# labels = imread("dataset_clean/VALIDATION/validation_labels.tif") +# try: +# labels_model = imread("results/watershed_based_model/instance_labels.tif") +# except: +# raise Exception( +# "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" +# ) +# +# evaluate_model_performance(labels, labels_model, visualize=True) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From 2b2c2c0b5fd307f9afc0decc13aa5ee6a7a5196b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 108/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From e870fa3db6b8596203f13706551d91e4af94d73e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 109/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- napari_cellseg3d/utils.py | 2 +- notebooks/assess_instance.ipynb | 553 ++++++++---------- requirements.txt | 4 +- 7 files changed, 568 insertions(+), 353 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index a8bb240b..77e5c981 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -136,7 +136,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 6a3b57d3..b2a40e0c 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -133,7 +133,7 @@ def resize(image, zoom_factors): mode="nearest-exact", padding_mode="empty", )(np.expand_dims(image, axis=0)) - return isotropic_image[0] + return isotropic_image[0].numpy() def align_array_sizes(array_shape, target_shape): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/requirements.txt b/requirements.txt index 93da070f..834a225e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ black coverage isort +itk pytest pytest-qt sphinx @@ -18,6 +19,7 @@ matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,scikit-image,einops]>=0.9.0 +monai[nibabel,einops]>=1.0.1 pillow +scikit-image>=0.19.2 vispy>=0.9.6 From 7548e13ac6f1a349855700a6c9cf4db7a65cea6d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 110/577] black --- .../code_models/model_instance_seg.py | 21 ++++++++---- napari_cellseg3d/code_models/model_workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 77e5c981..f3a04059 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -33,7 +33,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -56,7 +56,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -365,13 +372,13 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -411,13 +418,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -448,7 +455,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 47f63fec..1b4630cc 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -543,7 +543,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index a6a90eb4..c9b59357 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,7 +555,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -724,7 +726,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 107af8e6..6e9cc89e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -119,11 +119,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From 8df0c011e76b8e08835db35e6889ff4134cfb2c6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 111/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From 57fae3eede628a2f3648947323907787d2f11e88 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 112/577] Added pre-commit hooks --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 834a225e..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From d65e4f1c1cc3181ed4b5dbdf0b5f66baf327c54a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 113/577] Enfore pre-commit style --- .gitignore | 1 + .../_tests/test_plugin_inference.py | 2 - .../code_models/model_instance_seg.py | 8 +- .../code_plugins/plugin_model_inference.py | 3 - .../code_plugins/plugin_utilities.py | 1 - napari_cellseg3d/config.py | 3 - .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- napari_cellseg3d/utils.py | 10 +-- notebooks/assess_instance.ipynb | 79 +++++++++++++------ notebooks/csv_cell_plot.ipynb | 2 - 12 files changed, 78 insertions(+), 56 deletions(-) diff --git a/.gitignore b/.gitignore index d08ff9f2..e86beea4 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png *.prof + diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 584ffd3b..e15958e6 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,8 +7,6 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST - - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index f3a04059..f83bfd4d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function from dataclasses import dataclass from typing import List import numpy as np @@ -10,6 +8,7 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -531,14 +530,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index c9b59357..483679ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,9 +10,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 1f0d598b..6e3b9981 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,6 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QLayout from qtpy.QtWidgets import QSizePolicy from qtpy.QtWidgets import QVBoxLayout from qtpy.QtWidgets import QWidget diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6e9cc89e..3d5f2a1a 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,10 +8,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index b2a40e0c..3980502d 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -35,9 +35,7 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] @@ -412,17 +410,17 @@ def parse_default_path(possible_paths): def get_date_time(): """Get date and time in the following format : year_month_day_hour_minute_second""" - return "{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%Y_%m_%d_%H_%M_%S}" def get_time(): """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid)""" - return "{:%H:%M:%S}".format(datetime.now()) + return f"{datetime.now():%H:%M:%S}" def get_time_filepath(): """Get time in the following format : hour_minute_second. Compatible with saving""" - return "{:%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%H_%M_%S}" def load_images( diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, diff --git a/notebooks/csv_cell_plot.ipynb b/notebooks/csv_cell_plot.ipynb index 8b14fb8d..e00a9f1c 100644 --- a/notebooks/csv_cell_plot.ipynb +++ b/notebooks/csv_cell_plot.ipynb @@ -58,7 +58,6 @@ "outputs": [], "source": [ "def plot_data(data_path, x_inv=False, y_inv=False, z_inv=False):\n", - "\n", " data = pd.read_csv(data_path, index_col=False)\n", "\n", " x = data[\"Centroid x\"]\n", @@ -185,7 +184,6 @@ "outputs": [], "source": [ "def plotly_cells_stats(data):\n", - "\n", " init_notebook_mode() # initiate notebook for offline plot\n", "\n", " x = data[\"Centroid x\"]\n", From 98ad7798756cbbf50749ddcba2e7b2b90401b4d4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 114/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e86beea4..755de742 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,4 @@ notebooks/full_plot.html *.png *.prof + From 4761db2bc58d54ad23d564233ca8444227960ec3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 115/577] Version bump --- napari_cellseg3d/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..736c7f72 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1,2 @@ __version__ = "0.0.2rc6" + From 843cd35c5f52066d3b221bfc032545bd95290910 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 116/577] Updated project files --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 11b8dced..22ba04ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "nibabel", "scikit-image", "pillow", + "pyclesperanto-prototype", "tqdm", "matplotlib", "vispy>=0.9.6", @@ -61,6 +62,7 @@ dev = [ "ruff", "tuna", "pre-commit", + ] docs = [ "sphinx", @@ -75,3 +77,4 @@ test = [ "tox", "twine", ] + From f25a94dd0057789af5b8b80f18928e24dd1705e2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 117/577] Fixed missing parent error --- napari_cellseg3d/code_models/model_instance_seg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index f83bfd4d..ccdb5b18 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -448,7 +448,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From 768e11dcd7ef9d3fcdc00fd9267d0fe0a0ddbdac Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 118/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 35 ++++++++++++------- .../code_plugins/plugin_model_inference.py | 1 + 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index ccdb5b18..979f861c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -138,6 +138,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -146,7 +149,7 @@ def voronoi_otsu( def binary_connected( - volume, + volume: np.array, thres=0.5, thres_small=3, ): @@ -158,8 +161,12 @@ def binary_connected( thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 """ + logger.debug( + f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" + ) + # if len(volume.shape) > 3: semantic = np.squeeze(volume) - foreground = semantic > thres # int(255 * thres) + foreground = np.where(semantic > thres, volume, 0) # int(255 * thres) segm = label(foreground) segm = remove_small_objects(segm, thres_small) @@ -202,6 +209,10 @@ def binary_watershed( rem_seed_thres (int): threshold for small seeds removal. Default : 3 """ + logger.debug( + f"Running watershed segmentation with thres_objects={thres_objects}, thres_seeding={thres_seeding}," + f" thres_small={thres_small} and rem_seed_thres={rem_seed_thres}" + ) semantic = np.squeeze(volume) seed_map = semantic > thres_seeding foreground = semantic > thres_objects @@ -407,8 +418,8 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -441,7 +452,7 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -501,7 +512,7 @@ def __init__(self, parent=None): """ super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -520,7 +531,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -530,8 +541,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -555,9 +568,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 483679ef..fb6fb71c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -184,6 +184,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## From 67861b543be1b1ebe8119a62b19f923265fa9826 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 119/577] Removing dask-image --- .gitignore | 1 + napari_cellseg3d/dev_scripts/convert.py | 1 - napari_cellseg3d/utils.py | 106 ++++++++++++------------ 3 files changed, 53 insertions(+), 55 deletions(-) diff --git a/.gitignore b/.gitignore index 755de742..f8547d92 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,4 @@ notebooks/full_plot.html *.prof +*.prof diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index d772a1c2..479a07dd 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -20,7 +20,6 @@ # print(os.path.basename(filename)) for file in paths: image = imread(file) - # image = img.compute() image[image >= 1] = 1 image = image.astype(np.uint16) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 3980502d..02a5865e 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,7 +2,6 @@ import warnings from datetime import datetime from pathlib import Path - import numpy as np from pandas import DataFrame from pandas import Series @@ -16,7 +15,6 @@ # LOGGER.setLevel(logging.DEBUG) LOGGER.setLevel(logging.INFO) ############### - """ utils.py ==================================== @@ -276,52 +274,51 @@ def annotation_to_input(label_ermito): anno = normalize_x(anno[np.newaxis, :, :, :]) return anno - -def check_csv(project_path, ext): - if not Path(Path(project_path) / Path(project_path).name).is_file(): - cols = [ - "project", - "type", - "ext", - "z", - "y", - "x", - "z_size", - "y_size", - "x_size", - "created_date", - "update_date", - "path", - "notes", - ] - df = DataFrame(index=[], columns=cols) - filename_pattern_original = Path(project_path) / Path( - f"dataset/Original_size/Original/*{ext}" - ) - images_original = tfl_imread(filename_pattern_original) - z, y, x = images_original.shape - record = Series( - [ - Path(project_path).name, - "dataset", - ".tif", - 0, - 0, - 0, - z, - y, - x, - datetime.datetime.now(), - "", - Path(project_path) / Path("dataset/Original_size/Original"), - "", - ], - index=df.columns, - ) - df = df.append(record, ignore_index=True) - df.to_csv(Path(project_path) / Path(project_path).name) - else: - pass +# def check_csv(project_path, ext): +# if not Path(Path(project_path) / Path(project_path).name).is_file(): +# cols = [ +# "project", +# "type", +# "ext", +# "z", +# "y", +# "x", +# "z_size", +# "y_size", +# "x_size", +# "created_date", +# "update_date", +# "path", +# "notes", +# ] +# df = DataFrame(index=[], columns=cols) +# filename_pattern_original = Path(project_path) / Path( +# f"dataset/Original_size/Original/*{ext}" +# ) +# images_original = dask_imread(filename_pattern_original) +# z, y, x = images_original.shape +# record = Series( +# [ +# Path(project_path).name, +# "dataset", +# ".tif", +# 0, +# 0, +# 0, +# z, +# y, +# x, +# datetime.datetime.now(), +# "", +# Path(project_path) / Path("dataset/Original_size/Original"), +# "", +# ], +# index=df.columns, +# ) +# df = df.append(record, ignore_index=True) +# df.to_csv(Path(project_path) / Path(project_path).name) +# else: +# pass # def check_annotations_dir(project_path): @@ -464,6 +461,7 @@ def load_images( LOGGER.error( "Loading a stack this way is no longer supported. Use napari to load a stack." ) + else: images_original = tfl_imread( filename_pattern_original @@ -484,12 +482,12 @@ def load_images( # return base_label -def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): - images_label = load_images(mod_mask_dir, filetype, as_folder) - if as_folder: - images_label = images_label.compute() - base_label = images_label - return base_label +# def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): +# images_label = load_images(mod_mask_dir, filetype, as_folder) +# if as_folder: +# images_label = images_label.compute() +# base_label = images_label +# return base_label def save_stack(images, out_path, filetype=".png", check_warnings=False): From bfb1211c1af9cc10b583699ea7ffe7d8b48c2b33 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 120/577] Fixed erroneous dtype conversion --- napari_cellseg3d/code_models/model_instance_seg.py | 13 +++++++++++-- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 979f861c..436135a1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -137,12 +137,12 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -489,6 +489,15 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 7a59dcf0..c1493fa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -150,7 +150,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) save_layer( @@ -168,7 +168,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - utils.resize(np.array(imread(file), dtype=np.int16), zoom) + utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] save_folder( @@ -249,7 +249,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( @@ -330,7 +330,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) semantic = to_semantic(data) save_layer( @@ -414,7 +414,7 @@ def _start(self): if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) instance = self.instance_widgets.run_method(data) save_layer( @@ -509,7 +509,7 @@ def _start(self): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( From 35c72c4c2a2cd85b483d38eed65e1bb1227c44d5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 121/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index a5b6fd94..8fd297c2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,3 +1,7 @@ +from pathlib import Path +from tifffile import imread +import numpy as np + from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, @@ -8,9 +12,15 @@ def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(im_path) + view.add_image(image) + view.add_labels(image.astype(np.uint8)) + view.window.add_dock_widget(widget) for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + widget.utils_widgets[i]._start() From 3dbb12ac19c7c946fd44f9198244a20b26425f11 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:38:13 +0200 Subject: [PATCH 122/577] Temporary test action patch --- .github/workflows/test_and_deploy.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5dcd11ae..ea0a1e46 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,12 +8,14 @@ on: branches: - main - npe2 + - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 + - cy/voronoi-otsu workflow_dispatch: jobs: From c591b1cc347b91135ecab3323fa01cba445f1ffe Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:50:16 +0200 Subject: [PATCH 123/577] Update plugin_convert.py --- napari_cellseg3d/code_plugins/plugin_convert.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index c1493fa4..6908b7aa 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -34,7 +34,7 @@ def save_folder(results_path, folder_name, images, image_paths): image_paths: list of filenames of images """ results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False) + results_folder.mkdir(exist_ok=False, parents=True) for file, image in zip(image_paths, images): path = results_folder / Path(file).name @@ -143,7 +143,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -242,7 +242,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -324,7 +324,7 @@ def _build(self): ) def _start(self): - Path(self.results_path).mkdir(exist_ok=True) + Path(self.results_path).mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -408,7 +408,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -502,7 +502,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.binarize_counter.value() if self.layer_choice: From 59aaf444cd9a5876eed2654b5191b706fe00a0e3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 124/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 292b8fa4..46d84b40 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + pocl ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 263eb07d34cf692d20589cb61dc695e42fe878db Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 125/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 46d84b40..6ba5efac 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 2f2c517c74944d63d485bf810d69f0220011b4c6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 126/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6ba5efac..ee946a73 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 85834407ae96cb531a248e19c01f02f5b6202c34 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 127/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 8fd297c2..b2d9de52 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -23,4 +23,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index ee946a73..40a2a7a0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 31bfa2517f5e9721b98b03defe9879b033f20a59 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 128/577] Relabeling tests --- .gitignore | 6 +- napari_cellseg3d/_tests/res/test_labels.tif | Bin 0 -> 2026 bytes .../_tests/test_labels_correction.py | 51 ++++++++++ .../dev_scripts/artefact_labeling.py | 93 ++++++++---------- .../dev_scripts/correct_labels.py | 75 +++++++++----- 5 files changed, 151 insertions(+), 74 deletions(-) create mode 100644 napari_cellseg3d/_tests/res/test_labels.tif create mode 100644 napari_cellseg3d/_tests/test_labels_correction.py diff --git a/.gitignore b/.gitignore index f8547d92..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -106,5 +106,7 @@ notebooks/full_plot.html *.png *.prof - -*.prof +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/res/test_labels.tif b/napari_cellseg3d/_tests/res/test_labels.tif new file mode 100644 index 0000000000000000000000000000000000000000..0486d789ea658acc32616b40869833accf8d01d7 GIT binary patch literal 2026 zcmcK5yJ}QX6b9gPW+oRk7ZaUC6EDMfA4AYa#WzSNSc-*3f&q(wHX^pxK7x-TK7$V+ zh-hgcUQztNum?}HQam9)Yn^rd*V>ys8yll)x~i)As;YZc9c?nG8+xbi?%D^jcZTs? z;A_lNkw1zQI~mBsOB}G_#)iYWBJ~`{jpd=(^f$j8o7U4Q+J#zTzQ_J1_z;*uteC68 z$zUP)5+6=ue*NfXR{vChyE>l&{pDQ@Rs-|ldNz=U59v(k#{#wVv17fKBh6$ldYFA! zj2TD_w8=&Bcb7Z$2DI=OnVnep55ES3J|ZCRZ7^|q`;Z@w+f_hcAf uO7Fq{VSFjiRU3?7w#N8*ON^i72d14J-{`ip<7-oGF@Dt&<6PlCcKj3h Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 129/577] Run full suite of pre-commit hooks --- README.md | 2 +- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 3 +-- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- 9 files changed, 29 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index d6037a0d..8cd3fea1 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index b2d9de52..8dcd3c7e 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 436135a1..6d0dc13d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -14,8 +14,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -489,7 +489,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log From 08de49ab9e08ac7dc278c69c12b0d7a736c691d7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:08:38 +0200 Subject: [PATCH 130/577] Enforce style --- napari_cellseg3d/__init__.py | 1 - napari_cellseg3d/_tests/test_plugin_inference.py | 1 + napari_cellseg3d/_tests/test_plugin_utils.py | 4 +--- napari_cellseg3d/code_models/model_instance_seg.py | 8 +++++--- napari_cellseg3d/code_models/models/unet/model.py | 4 +--- napari_cellseg3d/code_plugins/plugin_convert.py | 2 ++ napari_cellseg3d/code_plugins/plugin_crop.py | 4 +--- napari_cellseg3d/code_plugins/plugin_review.py | 4 +--- napari_cellseg3d/code_plugins/plugin_utilities.py | 4 +--- napari_cellseg3d/config.py | 1 - napari_cellseg3d/interface.py | 5 +---- pyproject.toml | 1 - 12 files changed, 14 insertions(+), 25 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 736c7f72..11e8de0e 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1,2 +1 @@ __version__ = "0.0.2rc6" - diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e15958e6..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,6 +7,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 8dcd3c7e..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -4,9 +4,7 @@ from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, -) +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6d0dc13d..cc362eac 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import List + import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget @@ -9,14 +10,15 @@ from skimage.segmentation import watershed from tifffile import imread -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index c5cc78d3..6cc76be6 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -6,9 +6,7 @@ from napari_cellseg3d.code_models.models.unet.buildingblocks import ( create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( - DoubleConv, -) +from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6908b7aa..ed1a43df 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,5 +1,6 @@ import warnings from pathlib import Path + import napari import numpy as np from qtpy.QtWidgets import QSizePolicy @@ -354,6 +355,7 @@ def _start(self): self.images_filepaths, ) + class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 406ae7e7..fa4857aa 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -11,9 +11,7 @@ # local from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage DEFAULT_CROP_SIZE = 64 logger = utils.LOGGER diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 0044c8e2..a803dfd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -19,9 +19,7 @@ from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager warnings.formatwarning = utils.format_Warning diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 6e3b9981..c962717e 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -9,9 +9,7 @@ # local import napari_cellseg3d.interface as ui from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import ( - RemoveSmallUtils, -) +from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3d5f2a1a..ab3dba39 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -10,7 +10,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index a854905b..bb2a1efb 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,11 +7,8 @@ import napari # Qt -from qtpy import QtCore - # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt +from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt from qtpy.QtCore import QUrl diff --git a/pyproject.toml b/pyproject.toml index 22ba04ca..253af197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,4 +77,3 @@ test = [ "tox", "twine", ] - From 2b26e2c700eb4c2fb7d9fcac266805b38e32dd46 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 131/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 57 +++++++++++++++---- napari_cellseg3d/code_models/model_workers.py | 1 - .../code_plugins/plugin_convert.py | 1 + napari_cellseg3d/config.py | 9 ++- napari_cellseg3d/interface.py | 3 +- 6 files changed, 58 insertions(+), 14 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..584ffd3b 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,6 +8,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cc362eac..2c308c5d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -8,12 +8,18 @@ from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed + +from skimage.filters import thresholding +from skimage.transform import resize +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -82,6 +88,42 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -122,7 +164,6 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -390,7 +431,7 @@ def __init__(self, widget_parent=None): function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, + # widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -436,7 +477,7 @@ def __init__(self, widget_parent=None): function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, + # widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -475,8 +516,8 @@ def __init__(self, widget_parent=None): ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -504,6 +545,7 @@ def run_method(self, image): self.counters[0].value(), self.counters[1].value(), # self.counters[2].value(), + ) @@ -518,7 +560,6 @@ def __init__(self, parent=None): Args: parent: parent widget - """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -528,14 +569,12 @@ def __init__(self, parent=None): """Contains the instance of the method, with its name as key""" self.instance_widgets = {} """Contains the lists of widgets for each methods, to show/hide""" - self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) @@ -555,11 +594,11 @@ def _build(self): logger.debug( f"Caught runtime error {e}, most likely during testing" ) - self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: @@ -571,12 +610,10 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters - Args: volume: image data to run method on Returns: processed image from self._method - """ method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 1b4630cc..636f7acd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -449,7 +449,6 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): return post_process_transforms( self.config.model_info.get_model().get_output(model, inputs) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ed1a43df..f9d1c801 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -13,6 +13,7 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index ab3dba39..bb714f5d 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -9,7 +9,6 @@ import numpy as np from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -17,6 +16,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -126,7 +131,7 @@ class InstanceSegConfig: class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index bb2a1efb..90b102c4 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1049,12 +1049,13 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): + def _update_step(self): #FIXME check divide_factor if self.value() < 0.9: self.setSingleStep(0.01) else: self.setSingleStep(0.1) + @property def tooltips(self): return self.toolTip() From d06eb8cf5a684bf7fe74e50d9517b18ed3501655 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 132/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 40 +------------------ 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2c308c5d..19b5f5ba 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -87,43 +87,6 @@ def __init__( def run_method(self, image): raise NotImplementedError("Must be defined in child classes") - -class InstanceMethod: - def __init__( - self, - name: str, - function: callable, - num_sliders: int, - num_counters: int, - ): - self.name = name - self.function = function - self.counters: List[ui.DoubleIncrementCounter] = [] - self.sliders: List[ui.Slider] = [] - if num_sliders > 0: - for i in range(num_sliders): - widget = f"slider_{i}" - setattr( - self, - widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), - ) - self.sliders.append(getattr(self, widget)) - - if num_counters > 0: - for i in range(num_counters): - widget = f"counter_{i}" - setattr( - self, - widget, - ui.DoubleIncrementCounter(label=""), - ) - self.counters.append(getattr(self, widget)) - - def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") - - @dataclass class ImageStats: volume: List[float] @@ -431,7 +394,7 @@ def __init__(self, widget_parent=None): function=binary_watershed, num_sliders=2, num_counters=2, - # widget_parent=widget_parent, + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -545,7 +508,6 @@ def run_method(self, image): self.counters[0].value(), self.counters[1].value(), # self.counters[2].value(), - ) From 7b45f325d9506bb5778fe10c5ef0fc7372a13df2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 133/577] Added new docs for instance seg --- napari_cellseg3d/code_models/model_instance_seg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 19b5f5ba..d43625d8 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -51,7 +51,6 @@ def __init__( num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets - """ self.name = name self.function = function @@ -384,11 +383,11 @@ def fill(lst, n=len(properties) - 1): fill([len(properties)]), ) - class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" def __init__(self, widget_parent=None): + super().__init__( name=WATERSHED, function=binary_watershed, @@ -430,11 +429,11 @@ def run_method(self, image): self.counters[1].value(), ) - class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" def __init__(self, widget_parent=None): + super().__init__( name=CONNECTED_COMP, function=binary_connected, @@ -466,6 +465,7 @@ class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" def __init__(self, widget_parent=None): + super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -480,7 +480,6 @@ def __init__(self, widget_parent=None): self.counters[0].setMaximum(100) self.counters[0].setValue(2) self.counters[1].label.setText("Outline sigma") # smoothness - self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -585,4 +584,5 @@ def run_method(self, volume): VORONOI_OTSU: VoronoiOtsu, WATERSHED: Watershed, CONNECTED_COMP: ConnectedComponents, + } From 33da9733056654c6461619e2415aeb53cc56ae7d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 134/577] isort --- napari_cellseg3d/_tests/test_plugin_inference.py | 1 - napari_cellseg3d/code_models/model_instance_seg.py | 3 ++- napari_cellseg3d/code_plugins/plugin_convert.py | 1 - napari_cellseg3d/config.py | 7 +------ 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 584ffd3b..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,7 +8,6 @@ from napari_cellseg3d.config import MODEL_LIST - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d43625d8..6734b06f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,11 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget + from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - from skimage.filters import thresholding from skimage.transform import resize # from skimage.measure import mesh_surface_area @@ -20,6 +20,7 @@ from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f9d1c801..ed1a43df 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -13,7 +13,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index bb714f5d..a99924dc 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,6 +8,7 @@ import napari import numpy as np + from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -16,12 +17,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER From 1eb61e02d7c121c4c004e59cfc3adf53c6fdd6b3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 135/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 45c3be1c..3becfaca 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,3 @@ [pytest] qt_api=pyqt5 + From ff30837498695d63106452b872756cbe09c0e687 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 136/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- napari_cellseg3d/code_models/model_instance_seg.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6734b06f..e1b2eb03 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -17,10 +17,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -71,6 +69,7 @@ def __init__( text_label="", parent=None, ), + ) self.sliders.append(getattr(self, widget)) @@ -384,11 +383,11 @@ def fill(lst, n=len(properties) - 1): fill([len(properties)]), ) + class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" def __init__(self, widget_parent=None): - super().__init__( name=WATERSHED, function=binary_watershed, @@ -434,13 +433,12 @@ class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" def __init__(self, widget_parent=None): - super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - # widget_parent=widget_parent, + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -466,7 +464,6 @@ class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" def __init__(self, widget_parent=None): - super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -585,5 +582,4 @@ def run_method(self, volume): VORONOI_OTSU: VoronoiOtsu, WATERSHED: Watershed, CONNECTED_COMP: ConnectedComponents, - } From 2d67857b945dbd30141896d927788c28d9a04120 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 137/577] Fix inference --- napari_cellseg3d/code_models/model_instance_seg.py | 1 + napari_cellseg3d/code_plugins/plugin_model_inference.py | 1 + napari_cellseg3d/config.py | 6 +++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e1b2eb03..19d87a6a 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -528,6 +528,7 @@ def __init__(self, parent=None): """Contains the instance of the method, with its name as key""" self.instance_widgets = {} """Contains the lists of widgets for each methods, to show/hide""" + self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index fb6fb71c..e6fec55e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -556,6 +556,7 @@ def start(self): method=self.instance_widgets.methods[ self.instance_widgets.method_choice.currentText() ], + ) self.post_process_config = config.PostProcessConfig( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index a99924dc..34382460 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -115,6 +115,10 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class InstanceSegConfig: @@ -126,7 +130,7 @@ class InstanceSegConfig: class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ From dba19e1e59daec0a1611c01706d5ebd83225602a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 138/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- napari_cellseg3d/code_models/model_instance_seg.py | 4 ++++ napari_cellseg3d/dev_scripts/artefact_labeling.py | 11 ++++++----- napari_cellseg3d/dev_scripts/correct_labels.py | 4 +--- napari_cellseg3d/dev_scripts/evaluate_labels.py | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 19d87a6a..e279235e 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -50,6 +50,7 @@ def __init__( num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -519,6 +520,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -570,10 +572,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..aaf345cf 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -12,7 +12,6 @@ # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) - """ New code by Yves Paychere Creates labels of artifacts in an image based on existing labels of neurons @@ -78,7 +77,7 @@ def make_labels( Parameters ---------- image : str - Path to image. + image array path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -97,7 +96,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - # image = imread(image) + image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -107,6 +106,7 @@ def make_labels( image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( np.max(image_contrasted) - np.min(image_contrasted) ) + image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) @@ -126,7 +126,6 @@ def make_labels( image_contrasted.astype(np.float32), ) - def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters @@ -142,10 +141,12 @@ def select_image_by_labels(image, labels, path_image_out, label_values): """ # image = imread(image) # labels = imread(labels) + image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) + # select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) @@ -430,4 +431,4 @@ def create_artefact_labels_from_folder( # threshold_artefact_brightness_percent=20, # threshold_artefact_size_percent=1, # contrast_power=20, -# ) +# ) \ No newline at end of file diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..4c52675c 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -18,7 +18,6 @@ # import sys # sys.path.append(str(Path(__file__) / "../../")) - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -87,7 +86,6 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] - def ask_labels(unique_artefact, test=False): global returns returns = [] @@ -139,7 +137,6 @@ def ask_labels(unique_artefact, test=False): returns = [i_labels_to_add_tmp] print("close the napari window to continue") - def relabel( image_path, label_path, @@ -373,3 +370,4 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # gt_labels_path = str(im_path / "labels.tif") # # relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) + diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index ee9919b6..087a01bb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -503,4 +503,4 @@ def save_as_csv(results, path): # "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" # ) # -# evaluate_model_performance(labels, labels_model, visualize=True) +# evaluate_model_performance(labels, labels_model, visualize=True) \ No newline at end of file From 991b189e24b7bd6b82f1830c2c3c943065da08c6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 139/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- napari_cellseg3d/dev_scripts/evaluate_labels.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 087a01bb..e253eb2c 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -7,7 +7,6 @@ PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct - def evaluate_model_performance( labels, model_labels, @@ -47,7 +46,7 @@ def evaluate_model_performance( mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels, threshold_correct ) @@ -57,7 +56,7 @@ def evaluate_model_performance( # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) From e4c2a889d8230aca83eacb793cc06239f423ca0e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 140/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - napari_cellseg3d/dev_scripts/correct_labels.py | 1 - napari_cellseg3d/dev_scripts/evaluate_labels.py | 6 ++++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index aaf345cf..69d6535d 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4c52675c..9fcb2a88 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -17,7 +17,6 @@ # import sys # sys.path.append(str(Path(__file__) / "../../")) - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index e253eb2c..b74251f8 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,5 +1,7 @@ import napari import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -46,7 +48,7 @@ def evaluate_model_performance( mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels, threshold_correct ) @@ -56,7 +58,7 @@ def evaluate_model_performance( # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) From b7f571f141ae9e23e1a9ef0c1840f94df93413b6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 141/577] black --- napari_cellseg3d/code_models/model_instance_seg.py | 1 - napari_cellseg3d/code_plugins/plugin_model_inference.py | 1 - napari_cellseg3d/config.py | 2 ++ 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e279235e..b1d4d9b7 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -70,7 +70,6 @@ def __init__( text_label="", parent=None, ), - ) self.sliders.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index e6fec55e..fb6fb71c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -556,7 +556,6 @@ def start(self): method=self.instance_widgets.methods[ self.instance_widgets.method_choice.currentText() ], - ) self.post_process_config = config.PostProcessConfig( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 34382460..2f591621 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -115,11 +115,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class InstanceSegConfig: enabled: bool = False From 4135529725e8f4b364ed8443794ab865e976a7e8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 142/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 188 +++++++++++++++++- 1 file changed, 186 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b74251f8..6065520a 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,5 @@ import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -287,6 +285,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons ####################### # Slower version that was used for debugging From e25b17741ecfcb9035c2b6dfea10530c9b6f39fa Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 143/577] Enfore pre-commit style --- napari_cellseg3d/config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 2f591621..96025082 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -122,11 +122,6 @@ class InstanceSegConfig: method: InstanceMethod = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - @dataclass class PostProcessConfig: From 921881d80d160c37d11d0506321d6d3f781d3e68 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 144/577] Removing dask-image --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..ea3db7fc 100644 --- a/.gitignore +++ b/.gitignore @@ -110,3 +110,4 @@ notebooks/full_plot.html !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif + From e3f921b658f6c4fac5a5e1dbd59bf894e1141e2a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 145/577] Fixed erroneous dtype conversion --- .../code_models/model_instance_seg.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index b1d4d9b7..412c87d7 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,13 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget - from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.filters import thresholding -from skimage.transform import resize + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread @@ -20,10 +18,6 @@ from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - - # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : @@ -86,6 +80,7 @@ def __init__( def run_method(self, image): raise NotImplementedError("Must be defined in child classes") + @dataclass class ImageStats: volume: List[float] @@ -126,6 +121,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -165,6 +161,8 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 + scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) + """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -429,6 +427,7 @@ def run_method(self, image): self.counters[1].value(), ) + class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" @@ -477,6 +476,7 @@ def __init__(self, widget_parent=None): ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 @@ -492,6 +492,7 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + ################ # For debugging # import napari @@ -536,6 +537,7 @@ def __init__(self, parent=None): def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) + try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) @@ -555,11 +557,11 @@ def _build(self): logger.debug( f"Caught runtime error {e}, most likely during testing" ) + self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: From 1b268759680475612f41cde70535fd983fa4d6f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 146/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index cbfd97b2..7403f2b7 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,5 +1,4 @@ from pathlib import Path - import numpy as np from tifffile import imread From a03edb049ec4d59431e96785ea5c9382d46ddf91 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 147/577] Update tox.ini --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index 40a2a7a0..87338cd8 100644 --- a/tox.ini +++ b/tox.ini @@ -37,6 +37,5 @@ deps = pytest-qt qtpy ; pyopencl[pocl] -; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From c12884ccce91e73a3f4e82b433c24595888bfe06 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 148/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 1 - pyproject.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..d6bdc58e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: -# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort diff --git a/pyproject.toml b/pyproject.toml index 253af197..803338e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ dev = [ "ruff", "tuna", "pre-commit", - ] docs = [ "sphinx", From 74630666c2b4ce2b3ce4269cffe27164b6d156d7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 149/577] Run full suite of pre-commit hooks --- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/code_models/model_instance_seg.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 412c87d7..c72bafe9 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,8 +15,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -492,7 +492,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari From 9f228f2805e9abf74b64299c274f24970898f019 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 10:53:03 +0200 Subject: [PATCH 150/577] Enforce style --- .gitignore | 1 - napari_cellseg3d/_tests/pytest.ini | 1 - napari_cellseg3d/_tests/test_plugin_utils.py | 1 + napari_cellseg3d/config.py | 3 +-- napari_cellseg3d/dev_scripts/artefact_labeling.py | 5 +++-- napari_cellseg3d/dev_scripts/correct_labels.py | 3 ++- napari_cellseg3d/dev_scripts/evaluate_labels.py | 4 +++- napari_cellseg3d/interface.py | 3 +-- napari_cellseg3d/utils.py | 4 ++-- 9 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index ea3db7fc..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,3 @@ notebooks/full_plot.html !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif - diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 3becfaca..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,3 +1,2 @@ [pytest] qt_api=pyqt5 - diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 7403f2b7..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,4 +1,5 @@ from pathlib import Path + import numpy as np from tifffile import imread diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 96025082..ab3dba39 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,8 +8,8 @@ import napari import numpy as np - from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -122,7 +122,6 @@ class InstanceSegConfig: method: InstanceMethod = None - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 69d6535d..90048a60 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage @@ -125,6 +126,7 @@ def make_labels( image_contrasted.astype(np.float32), ) + def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters @@ -145,7 +147,6 @@ def select_image_by_labels(image, labels, path_image_out, label_values): imwrite(path_image_out, image.astype(np.float32)) - # select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) @@ -430,4 +431,4 @@ def create_artefact_labels_from_folder( # threshold_artefact_brightness_percent=20, # threshold_artefact_size_percent=1, # contrast_power=20, -# ) \ No newline at end of file +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9fcb2a88..aacf08f8 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -85,6 +85,7 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] + def ask_labels(unique_artefact, test=False): global returns returns = [] @@ -136,6 +137,7 @@ def ask_labels(unique_artefact, test=False): returns = [i_labels_to_add_tmp] print("close the napari window to continue") + def relabel( image_path, label_path, @@ -369,4 +371,3 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # gt_labels_path = str(im_path / "labels.tif") # # relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) - diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 6065520a..bd2f0768 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -7,6 +7,7 @@ PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + def evaluate_model_performance( labels, model_labels, @@ -285,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### @@ -688,4 +690,4 @@ def save_as_csv(results, path): # "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" # ) # -# evaluate_model_performance(labels, labels_model, visualize=True) \ No newline at end of file +# evaluate_model_performance(labels, labels_model, visualize=True) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 90b102c4..ff3af55c 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1049,13 +1049,12 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): #FIXME check divide_factor + def _update_step(self): # FIXME check divide_factor if self.value() < 0.9: self.setSingleStep(0.01) else: self.setSingleStep(0.1) - @property def tooltips(self): return self.toolTip() diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 02a5865e..1ddbe67d 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,9 +2,8 @@ import warnings from datetime import datetime from pathlib import Path + import numpy as np -from pandas import DataFrame -from pandas import Series from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -274,6 +273,7 @@ def annotation_to_input(label_ermito): anno = normalize_x(anno[np.newaxis, :, :, :]) return anno + # def check_csv(project_path, ext): # if not Path(Path(project_path) / Path(project_path).name).is_file(): # cols = [ From 393de28b68557a785a8f432bee549aa8aa0bfb9b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 14:29:40 +0200 Subject: [PATCH 151/577] Documentation update, crop contrast fix --- docs/res/guides/cropping_module_guide.rst | 6 +++--- docs/res/guides/utils_module_guide.rst | 10 +++++++++- docs/res/welcome.rst | 21 +++++++++++++------- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/docs/res/guides/cropping_module_guide.rst b/docs/res/guides/cropping_module_guide.rst index a862ffff..89cbb39a 100644 --- a/docs/res/guides/cropping_module_guide.rst +++ b/docs/res/guides/cropping_module_guide.rst @@ -33,9 +33,9 @@ If you'd like to change the size of the volume, change the parameters as previou Creating new layers --------------------------------- -To "zoom in" your volume, you can use the "Create new layers" checkbox to make a new layer not controlled by the plugin next -time you hit Start. This way, you can first select your region of interest by using the tool as described above, -the enable the option, select the cropped layer, and define a smaller crop size to have easier access to your region of interest. +To "zoom in" your volume, you can use the "Create new layers" checkbox to make a new cropping layer controlled by the sliders +next time you hit Start. This way, you can first select your region of interest by using the tool as described above, +then enable the option, select the cropped region produced before as the input layer, and define a smaller crop size in order to crop within your region of interest. Interface & functionalities --------------------------------------------------------------- diff --git a/docs/res/guides/utils_module_guide.rst b/docs/res/guides/utils_module_guide.rst index 407ae710..64e8a3ce 100644 --- a/docs/res/guides/utils_module_guide.rst +++ b/docs/res/guides/utils_module_guide.rst @@ -4,13 +4,21 @@ Label conversion utility guide ================================== This utility will let you convert labels to various different formats. + You will have to specify the results directory for saving; afterwards you can run each action on a folder or on the currently selected layer. You can : +* Crop 3D volumes : + Please refer to :ref:`cropping_module_guide` for a guide on using the cropping utility. + * Convert to instance labels : - This will convert 0/1 semantic labels to instance label, with a unique ID for each object using the watershed method. + This will convert 0/1 semantic labels to instance label, with a unique ID for each object. + The available methods for this are : + * Connected components : simple method that will assign a unique ID to each connected component. Does not work well for touching objects (objects will often be fused), works for anisotropic volumes. + * Watershed : method based on topographic maps. Works well for touching objects and anisotropic volumes; touching objects may be fused. + * Voronoi-Otsu : method based on Voronoi diagrams. Works well for touching objects but only for isotropic volumes. * Convert to semantic labels : This will convert instance labels with unique IDs per object into 0/1 semantic labels, for example for training. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index d2f2c0f0..892549a8 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -38,22 +38,28 @@ You can install `napari-cellseg3d` via [pip]: ``pip install napari-cellseg3d`` - For local installation, please run: +For local installation after cloning, please run in the CellSeg3D folder: ``pip install -e .`` Requirements -------------------------------------------- +.. note:: + A **CUDA-capable GPU** is not needed but **very strongly recommended**, especially for training and possibly inference. + .. important:: - A **CUDA-capable GPU** is not needed but **very strongly recommended**, especially for training. + This package requires you have napari installed with PyQt5 or PySide2 first. + If you do not have a Qt backend you can use : -This package requires you have napari installed first. + ``pip install napari-cellseg3d[all]`` + to install PyQt5 by default. -It also depends on PyTorch and some optional dependencies of MONAI. These come in the pip package above, but if +It also depends on PyTorch and some optional dependencies of MONAI. These come in the pip package as requirements, but if you need further assistance see below. * For help with PyTorch, please see `PyTorch's website`_ for installation instructions, with or without CUDA depending on your hardware. + Depending on your setup, you might wish to install torch first. * If you get errors from MONAI regarding missing readers, please see `MONAI's optional dependencies`_ page for instructions on getting the readers required by your images. @@ -70,14 +76,13 @@ To use the plugin, please run: Then go into Plugins > napari-cellseg3d, and choose which tool to use: - - **Review**: This module allows you to review your labels, from predictions or manual labeling, and correct them if needed. It then saves the status of each file in a csv, for easier monitoring - **Inference**: This module allows you to use pre-trained segmentation algorithms on volumes to automatically label cells - **Training**: This module allows you to train segmentation algorithms from labeled volumes - **Utilities**: This module allows you to use several utilities, e.g. to crop your volumes and labels, compute prediction scores or convert labels - **Help/About...** : Quick access to version info, Github page and docs -See above for links to detailed guides regarding the usage of the modules. +See the documentation for links to detailed guides regarding the usage of the modules. Acknowledgments & References --------------------------------------------- @@ -98,6 +103,7 @@ This plugin mainly uses the following libraries and software: * `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase +* A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ .. _Wyss Center: https://wysscenter.ch/ @@ -107,10 +113,11 @@ This plugin mainly uses the following libraries and software: .. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype - +.. _WNet model: https://arxiv.org/abs/1711.08506 .. rubric:: References .. [#] Mapping mesoscale axonal projections in the mouse brain using a 3D convolutional network, Friedmann et al., 2020 ( https://pnas.org/cgi/doi/10.1073/pnas.1918465117 ) .. [#] The mesoSPIM initiative: open-source light-sheet microscopes for imaging cleared tissue, Voigt et al., 2019 ( https://doi.org/10.1038/s41592-019-0554-0 ) .. [#] MONAI Project website ( https://monai.io/ ) +.. [#] W-Net: A Deep Model for Fully Unsupervised Image Segmentation, Xia and Kulis, 2018 ( https://arxiv.org/abs/1711.08506 ) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index fa4857aa..cb149b52 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -306,7 +306,7 @@ def _start(self): else: self.image_layer1.opacity = 0.7 self.image_layer1.colormap = "inferno" - self.image_layer1.contrast_limits = [200, 1000] # TODO generalize + # self.image_layer1.contrast_limits = [200, 1000] # TODO generalize self.image_layer1.refresh() From 1a39a74062bea2ccccfaf9e932611e8f2e16d753 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 17:41:05 +0200 Subject: [PATCH 152/577] Fixes and improvements - More CRF info - Added error handling to scheduler rate - Added ETA to training - Updated padding warning trigger size --- napari_cellseg3d/code_models/crf.py | 30 ++++++++++------ napari_cellseg3d/code_models/model_workers.py | 34 ++++++++++++++----- .../code_models/models/model_VNet.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 6 ++++ .../code_plugins/plugin_model_inference.py | 3 ++ .../code_plugins/plugin_model_training.py | 6 ++-- napari_cellseg3d/utils.py | 6 ++-- 7 files changed, 61 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index a0146a5e..1b8dce28 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -33,6 +33,7 @@ from napari.qt.threading import GeneratorWorker from napari_cellseg3d.config import CRFConfig +from napari_cellseg3d.utils import LOGGER as logger __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ @@ -52,12 +53,16 @@ ] -def correct_shape_for_crf(image): - if len(image.shape) == 4: +def correct_shape_for_crf(image, desired_dims=4): + if len(image.shape) == desired_dims: return image - if len(image.shape) > 4: + if len(image.shape) > desired_dims: + if image.shape[0] > 1: + raise ValueError( + f"Image shape {image.shape} might have several channels" + ) image = np.squeeze(image, axis=0) - if len(image.shape) < 4: + if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) return correct_shape_for_crf(image) @@ -146,7 +151,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): ) -def crf_with_config(image, prob, config: CRFConfig = None): +def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): if config is None: config = CRFConfig() if image.shape[-3:] != prob.shape[-3:]: @@ -156,6 +161,12 @@ def crf_with_config(image, prob, config: CRFConfig = None): ) image = correct_shape_for_crf(image) + prob = correct_shape_for_crf(prob) + + if log is not None: + log("Running CRF post-processing step") + log(f"Image shape : {image.shape}") + log(f"Labels shape : {prob.shape}") return crf( image, @@ -196,15 +207,12 @@ def _run_crf_job(self): raise ImportError("pydensecrf is not installed.") for image, labels in zip(self.images, self.labels): - if len(image.shape) == 3: - image = np.expand_dims(image, axis=0) - - if len(labels.shape) == 3: - labels = np.expand_dims(labels, axis=0) - if image.shape[-3:] != labels.shape[-3:]: raise ValueError("Image and labels must have the same shape.") + image = correct_shape_for_crf(image) + labels = correct_shape_for_crf(labels) + yield crf( image, labels, diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 06285aea..4ab0c936 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,4 +1,5 @@ import platform +import time import typing as t from dataclasses import dataclass from math import ceil @@ -599,7 +600,7 @@ def save_image( filetype = self.config.filetype else: original_filename = "_" - filetype = "" + filetype = ".tif" time = utils.get_date_time() @@ -713,12 +714,9 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) def run_crf(self, image, labels, image_id=0): - self.log(f"IMAGE SHAPE : {image.shape}") - self.log(f"LABEL SHAPE : {labels.shape}") - try: crf_results = crf_with_config( - image, labels, config=self.config.crf_config + image, labels, config=self.config.crf_config, log=self.log ) self.save_image( crf_results, i=image_id, additional_info="CRF", from_layer=True @@ -1153,6 +1151,8 @@ def train(self): weights_config = self.config.weights_info deterministic_config = self.config.deterministic_config + start_time = time.time() + try: if deterministic_config.enabled: set_determinism( @@ -1365,14 +1365,23 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + + factor = self.config.scheduler_factor + if factor >= 1.0: + self.log(f"Warning : scheduler factor is {factor} >= 1.0") + self.log("Setting it to 0.5") + factor = 0.5 + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode="min", - factor=self.config.scheduler_factor, + factor=factor, patience=self.config.scheduler_patience, verbose=VERBOSE_SCHEDULER, ) - dice_metric = DiceMetric(include_background=True, reduction="mean") + dice_metric = DiceMetric( + include_background=False, reduction="mean" + ) best_metric = -1 best_metric_epoch = -1 @@ -1468,6 +1477,15 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] + self.log( + "ETA: " + + str( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 + ) + + "minutes" + ) if ( (epoch + 1) % self.config.validation_interval == 0 @@ -1491,7 +1509,7 @@ def train(self): overlap=0.25, sw_device=self.config.device, device=self.config.device, - progress=True, + progress=False, ) except Exception as e: self.raise_error(e, "Error during validation") diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 41554e80..7aa6476e 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=1, **kwargs): + def __init__(self, in_channels=1, out_channels=2, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index cbdacf3a..7ac605e9 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -178,6 +178,11 @@ def _build(self): def make_config(self): return self.crf_params_widget.make_config() + def print_config(self): + logger.info("CRF config:") + for item in self.make_config().__dict__.items(): + logger.info(f"{item[0]}: {item[1]}") + def _check_ready(self): if len(self.label_layer_loader.layer_list) < 1: logger.warning("No label layer loaded") @@ -272,6 +277,7 @@ def _on_start(self): def _on_finish(self): self.worker = None + self.start_button.setText("Start") def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 03381779..9c24cf1b 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -802,6 +802,9 @@ def on_yield(self, result: InferenceResult): # ) if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) viewer.add_image( result.crf_results, name=f"CRF_results_image_{image_id}", diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 88991f43..86d1d317 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -846,7 +846,7 @@ def start(self): loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), @@ -982,7 +982,7 @@ def on_yield(self, report: TrainingReport): layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(layer) else: @@ -993,7 +993,7 @@ def on_yield(self, report: TrainingReport): new_layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(new_layer) self.result_layers[i].data = report.images[i] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 171c20f0..f84dbe8b 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -12,8 +12,8 @@ LOGGER = logging.getLogger(__name__) ############### # Global logging level setting -# LOGGER.setLevel(logging.DEBUG) -LOGGER.setLevel(logging.INFO) +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) ############### """ utils.py @@ -312,7 +312,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 - if pad >= 256: + if pad >= 1024: LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." From 85cc3453aec87fe981858b1357766c25c9ec3998 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 3 May 2023 09:57:34 +0200 Subject: [PATCH 153/577] Fixes and channel labeling prototype --- napari_cellseg3d/code_models/model_workers.py | 33 +++-- .../extract_extra_channels_labels.py | 124 ++++++++++++++++++ 2 files changed, 143 insertions(+), 14 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 4ab0c936..7b51bd57 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -547,12 +547,14 @@ def create_inference_result( "A layer's ID should always be 0 (default value)" ) extra_dims = len(semantic_labels.shape) - 3 - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + if semantic_labels is not None: + semantic_labels = np.swapaxes( + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + if crf_results is not None: + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims + ) return InferenceResult( image_id=i + 1, @@ -1457,6 +1459,12 @@ def train(self): optimizer.zero_grad() outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") + if outputs.shape[1] > 1: + outputs = outputs[ + :, 1:, :, : + ] # FIXME fix channel number + if len(outputs.shape) < 4: + outputs = outputs.unsqueeze(0) loss = self.config.loss_function(outputs, labels) loss.backward() optimizer.step() @@ -1477,15 +1485,12 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] - self.log( - "ETA: " - + str( - (time.time() - start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 - ) - + "minutes" + eta = ( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 ) + self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py new file mode 100644 index 00000000..2bd0a536 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -0,0 +1,124 @@ +import numpy as np +from skimage.filters import threshold_otsu +from skimage.segmentation import expand_labels +from tqdm import tqdm + + +def extract_labels_from_channels( + nucleus_labels: np.array, + extra_channels: list, + radius: int = 4, + threshold_factor=2, + viewer=None, +): + """ + Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. + Args: + nucleus_labels (np.array): labels for the nuclei + extra_channels (list): channels arrays to extract labels from + radius: radius in which the approximation is made + + Returns: + A list of extracted labels for each extra channel + """ + labeled_channels = {} + + contrasted_channels = [] + for channel in extra_channels: + channel = (channel - np.min(channel)) / ( + np.max(channel) - np.min(channel) + ) + threshold_brightness = threshold_otsu(channel) * threshold_factor + channel_contrasted = np.where( + channel > threshold_brightness, channel, 0 + ) + contrasted_channels.append(channel_contrasted) + if viewer is not None: + viewer.add_image( + channel_contrasted, + name="channel_contrasted", + colormap="viridis", + ) + for label_id in tqdm(np.unique(nucleus_labels)): + if label_id == 0: + continue + label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + expanded = expand_labels(label_nucleus, distance=radius) + for i, channel in enumerate(contrasted_channels): + label_contrasted = np.where(expanded != 0, channel, 0) + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels[ + f"label_{label_id}_channel_{i+1}" + ] = np.count_nonzero(labeled_channel) + if np.count_nonzero(labeled_channel) > 0 and viewer is not None: + print(np.count_nonzero(labeled_channel)) + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) + + return labeled_channels + + +if __name__ == "__main__": + from pathlib import Path + + import napari + import pandas as pd + from tifffile import imread + + image_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + ) + # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + nuclei_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" + ) + extra_channels_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" + ) + extra_channels = [ + imread(str(path)) + for path in extra_channels_path.glob( + "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" + ) + ] + labels = imread(str(image_path)) + viewer = napari.Viewer() + + shift = 0 + viewer.add_image( + imread(str(nuclei_path))[ + shift : 32 + shift, shift : 32 + shift, shift : 32 + shift + ], + name="nuclei", + ) + viewer.add_labels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + [ + viewer.add_image( + channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + for channel in extra_channels + ] + + labeled_channels = extract_labels_from_channels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], + [ + c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + for c in extra_channels + ], + radius=4, + viewer=viewer, + ) + table = pd.DataFrame( + labeled_channels.items(), columns=["name", "pixels count"] + ) + print(table) + # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] + # expanded = expand_labels(labels, 4) + # viewer.add_labels(expanded) + napari.run() From 28a4c2a28e10abebf8c7290a454f4ad4adae17a7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 09:18:42 +0200 Subject: [PATCH 154/577] Fixes - Fixed multi-channel instance and csv stats - Fixed rotation of inference outputs - Raised max crop size --- napari_cellseg3d/code_models/model_workers.py | 76 +++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 4 +- .../code_plugins/plugin_model_inference.py | 79 +++++++++---------- .../extract_extra_channels_labels.py | 64 +++++++++------ napari_cellseg3d/interface.py | 54 ++++++++----- napari_cellseg3d/utils.py | 6 ++ 6 files changed, 161 insertions(+), 122 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7b51bd57..5529cfef 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -546,15 +546,15 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - extra_dims = len(semantic_labels.shape) - 3 + if semantic_labels is not None: - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) + semantic_labels = utils.correct_rotation(semantic_labels) if crf_results is not None: - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + crf_results = utils.correct_rotation(crf_results) + if instance_labels is not None: + instance_labels = utils.correct_rotation( + instance_labels + ) # TODO(cyril) check if correct return InferenceResult( image_id=i + 1, @@ -580,10 +580,6 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) - if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -609,10 +605,11 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"{additional_info}_Prediction_{i+1}" + + f"{additional_info}" + + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}_" + + f"_{time}" + filetype ) try: @@ -639,18 +636,20 @@ def aniso_transform(self, image): return image def instance_seg( - self, to_instance, image_id=0, original_filename="layer", channel=None + self, semantic_labels, image_id=0, original_filename="layer" ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) - if channel is not None: - channel_id = f"_{channel}" + if len(semantic_labels.shape) == 4: + instance_labels = np.array( + [method.run_method(ch) for ch in semantic_labels] + ) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") else: - channel_id = "" + instance_labels = method.run_method(image=semantic_labels) if self.config.filetype == "": filetype = "" @@ -662,7 +661,6 @@ def instance_seg( + "/" + f"Instance_seg_labels_{image_id}_" + original_filename - + channel_id + "_" + self.config.model_info.name + f"_{utils.get_date_time()}" @@ -721,7 +719,10 @@ def run_crf(self, image, labels, image_id=0): image, labels, config=self.config.crf_config, log=self.log ) self.save_image( - crf_results, i=image_id, additional_info="CRF", from_layer=True + crf_results, + i=image_id, + additional_info="CRF_", + from_layer=True, ) return crf_results except ValueError as e: @@ -729,14 +730,17 @@ def run_crf(self, image, labels, image_id=0): return None def stats_csv(self, instance_labels): - if self.config.compute_stats: - stats = volume_stats(instance_labels) - return stats - - # except ValueError as e: - # self.log(f"Error occurred during stats computing : {e}") - # return None - else: + try: + if self.config.compute_stats: + if len(instance_labels.shape) == 4: + stats = [volume_stats(c) for c in instance_labels] + else: + stats = [volume_stats(instance_labels)] + return stats + else: + return None + except ValueError as e: + self.log(f"Error occurred during stats computing : {e}") return None def inference_on_layer(self, image, model, post_process_transforms): @@ -754,15 +758,9 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels_results = [] - stats_results = [] - - for channel in out: - instance_labels, stats = self.get_instance_result( - channel, from_layer=True - ) - instance_labels_results.append(instance_labels) - stats_results.append(stats) + instance_labels, stats = self.get_instance_result( + semantic_labels=out, from_layer=True + ) if self.config.use_crf: crf_results = self.run_crf(image, out) @@ -771,10 +769,10 @@ def inference_on_layer(self, image, model, post_process_transforms): return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels_results, + instance_labels=instance_labels, crf_results=crf_results, from_layer=True, - stats=stats_results, + stats=stats, ) # @thread_worker(connect={"errored": self.raise_error}) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 7363d91c..3ae8a4eb 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -80,7 +80,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.results_filewidget.check_ready() self.crop_size_widgets = ui.IntIncrementCounter.make_n( - 3, 1, 1000, DEFAULT_CROP_SIZE + 3, 1, 10000, DEFAULT_CROP_SIZE ) self.crop_size_labels = [ ui.make_label("Size in " + axis + " of cropped volume :", self) @@ -310,7 +310,7 @@ def _start(self): else: self.image_layer1.opacity = 0.7 self.image_layer1.colormap = "inferno" - self.image_layer1.contrast_limits = [200, 1000] # TODO generalize + # self.image_layer1.contrast_limits = [200, 1000] # TODO generalize self.image_layer1.refresh() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 9c24cf1b..57bb740f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -140,7 +140,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.thresholding_slider = ui.Slider( - lower=1, default=config.PostProcessConfig().thresholding.threshold_value * 100, divide_factor=100.0, @@ -435,10 +434,10 @@ def _build(self): self.anisotropy_wdgt, # anisotropy self.thresholding_checkbox, self.thresholding_slider.container, # thresholding - self.use_instance_choice, - self.instance_widgets, self.use_crf, self.crf_widgets, + self.use_instance_choice, + self.instance_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -754,61 +753,61 @@ def on_yield(self, result: InferenceResult): name=f"pred_{image_id}_{model_name}", opacity=0.8, ) + if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) if ( len(result.instance_labels) > 0 and self.worker_config.post_process_config.instance.enabled ): - for i, labels in enumerate(result.instance_labels): - # labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(result.instance_labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(result.instance_labels, name=name) from napari_cellseg3d.utils import LOGGER as log - log.debug(f"len stats : {len(result.stats)}") + if result.stats is not None and isinstance( + result.stats, list + ): + log.debug(f"len stats : {len(result.stats)}") - for i, stats in enumerate(result.stats): - # stats = result.stats + for i, stats in enumerate(result.stats): + # stats = result.stats - if ( - self.worker_config.compute_stats - and stats is not None - ): - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances in channel {i} : {stats.number_objects[0]}" - ) + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) - - if result.crf_results is not None: - logger.debug( - f"CRF results shape : {result.crf_results.shape}" - ) - viewer.add_image( - result.crf_results, - name=f"CRF_results_image_{image_id}", - colormap="viridis", - ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py index 2bd0a536..70ee10b6 100644 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -4,8 +4,8 @@ from tqdm import tqdm -def extract_labels_from_channels( - nucleus_labels: np.array, +def extract_labels_from_channels( # TODO add separate channels results + nuclei_labels: np.array, extra_channels: list, radius: int = 4, threshold_factor=2, @@ -14,15 +14,14 @@ def extract_labels_from_channels( """ Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. Args: - nucleus_labels (np.array): labels for the nuclei + nuclei_labels (np.array): labels for the nuclei extra_channels (list): channels arrays to extract labels from radius: radius in which the approximation is made Returns: A list of extracted labels for each extra channel """ - labeled_channels = {} - + labeled_channels = [] contrasted_channels = [] for channel in extra_channels: channel = (channel - np.min(channel)) / ( @@ -39,31 +38,54 @@ def extract_labels_from_channels( name="channel_contrasted", colormap="viridis", ) - for label_id in tqdm(np.unique(nucleus_labels)): + for label_id in tqdm(np.unique(nuclei_labels)): if label_id == 0: continue - label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) expanded = expand_labels(label_nucleus, distance=radius) + restricted = np.where(expanded != 0, nuclei_labels, 0) + overlap = np.where(restricted != label_id, restricted, 0) + for i, channel in enumerate(contrasted_channels): label_contrasted = np.where(expanded != 0, channel, 0) - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels[ - f"label_{label_id}_channel_{i+1}" - ] = np.count_nonzero(labeled_channel) - if np.count_nonzero(labeled_channel) > 0 and viewer is not None: - print(np.count_nonzero(labeled_channel)) - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) + if overlap.any() != 0: + max_labeled = 0 + for overlap_id in np.unique(overlap): + if overlap_id == 0: + continue + assigned_pixels = np.count_nonzero( + np.where(overlap == overlap_id, channel, 0) + ) + if assigned_pixels > max_labeled: + max_labeled = assigned_pixels + max_label_id = overlap_id + if label_id != max_label_id: + labeled_channels.append( + np.zeros_like(label_contrasted) + ) + else: + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels.append(labeled_channel) + if ( + np.count_nonzero(labeled_channel) > 0 + and viewer is not None + ): + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) - return labeled_channels + cat_labels = np.zeros_like(nuclei_labels) + for labels in np.unique(labeled_channels): + if labels == 0: + continue + cat_labels += np.where(labels != 0, labels, 0) + return cat_labels if __name__ == "__main__": from pathlib import Path import napari - import pandas as pd from tifffile import imread image_path = ( @@ -114,10 +136,8 @@ def extract_labels_from_channels( radius=4, viewer=viewer, ) - table = pd.DataFrame( - labeled_channels.items(), columns=["name", "pixels count"] - ) - print(table) + + viewer.add_labels(labeled_channels) # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] # expanded = expand_labels(labels, 4) # viewer.add_labels(expanded) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 209b093e..574ef23a 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -470,6 +470,11 @@ def __init__( ): super().__init__(orientation, parent) + if upper <= lower: + raise ValueError( + "The minimum value cannot be below the maximum one" + ) + self.setMaximum(upper) self.setMinimum(lower) self.setSingleStep(step) @@ -545,23 +550,29 @@ def _warn_outside_bounds(self, default): def _update_slider(self): """Update slider when value is changed""" - if self._value_label.text() == "": - return + try: + if self._value_label.text() == "": + return - value = float(self._value_label.text()) * self._divide_factor + value = float(self._value_label.text()) * self._divide_factor - if value < self.minimum(): - self.slider_value = self.minimum() - return - if value > self.maximum(): - self.slider_value = self.maximum() - return + if value < self.minimum(): + self.slider_value = self.minimum() + return + if value > self.maximum(): + self.slider_value = self.maximum() + return - self.slider_value = value + self.slider_value = value + except Exception as e: + logger.error(e) def _update_value_label(self): """Update label, to connect to when slider is dragged""" - self._value_label.setText(str(self.value_text)) + try: + self._value_label.setText(str(self.value_text)) + except Exception as e: + logger.error(e) @property def tooltips(self): @@ -597,16 +608,21 @@ def value_text(self): def slider_value(self, value: int): """Set a value (int) divided by self._divide_factor""" if value < self.minimum() or value > self.maximum(): - raise ValueError( - f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + logger.error( + ValueError( + f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + ) ) - self.setValue(int(value)) - - divided = value / self._divide_factor - if self._divide_factor == 1.0: - divided = int(divided) - self._value_label.setText(str(divided)) + try: + self.setValue(int(value)) + + divided = value / self._divide_factor + if self._divide_factor == 1.0: + divided = int(divided) + self._value_label.setText(str(divided)) + except Exception as e: + logger.error(e) class AnisotropyWidgets(QWidget): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index f84dbe8b..49aea584 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -202,6 +202,12 @@ def dice_coeff(y_true, y_pred): return score +def correct_rotation(image): + """Rotates the exes 0 and 2 in [DHW] section of image array""" + extra_dims = len(image) - 3 + return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, From d6767107ff7ba7914bdf40d934b1ba818ab51e9b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 14:42:02 +0200 Subject: [PATCH 155/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 57bb740f..78cf4438 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -762,9 +762,8 @@ def on_yield(self, result: InferenceResult): name=f"CRF_results_image_{image_id}", colormap="viridis", ) - if ( - len(result.instance_labels) > 0 + result.instance_labels is not None and self.worker_config.post_process_config.instance.enabled ): method_name = ( From 47ba0811551adccde4fdd9ec99004fb47dc87469 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 6 May 2023 09:56:17 +0200 Subject: [PATCH 156/577] Update plugin_crop.py --- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 3ae8a4eb..e2189a15 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -49,7 +49,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.label_layer_loader.layer_list.label.setText("Image 2") self.crop_second_image_choice = ui.CheckBox( - "Crop another\nimage simultaneously", + "Crop another\nimage/label simultaneously", ) self.crop_second_image_choice.toggled.connect( self._toggle_second_image_io_visibility From 8a56ca57e9c8acc515ccc1ee0da6256c2054df5b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 10:16:58 +0200 Subject: [PATCH 157/577] Fixed patch_func sample number mismatch --- napari_cellseg3d/code_models/model_workers.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 5529cfef..3ecc47f0 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1229,30 +1229,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if do_sampling: - sample_loader = Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=self.config.num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), - ), - EnsureTyped(keys=["image", "label"]), - ] - ) if self.config.do_augmentation: train_transforms = ( @@ -1284,6 +1260,31 @@ def train(self): ] ) # self.log("Loading dataset...\n") + def get_loader_func(num_samples): + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) + ), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + if do_sampling: # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior @@ -1296,11 +1297,17 @@ def train(self): self.config.num_samples * (1 - self.config.validation_percent) ) + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) else: num_train_samples = ( num_val_samples ) = self.config.num_samples + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) + + logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" @@ -1310,14 +1317,14 @@ def train(self): train_ds = PatchDataset( data=self.train_files, transform=train_transforms, - patch_func=sample_loader, + patch_func=sample_loader_train, samples_per_image=num_train_samples, ) logger.debug("val_ds") val_ds = PatchDataset( data=self.val_files, transform=val_transforms, - patch_func=sample_loader, + patch_func=sample_loader_eval, samples_per_image=num_val_samples, ) From 4745e183004b5e2a26d0a4c6a7d7d5299f846b9f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:08:52 +0200 Subject: [PATCH 158/577] Testing relabel tools --- napari_cellseg3d/dev_scripts/correct_labels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2ab60332..9862c3fa 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -367,8 +367,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") +# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") # -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# image_path = str(im_path / "volumes/images.tif") +# gt_labels_path = str(im_path / "labels/testing_im.tif") +# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) From d148b9ece49d7a333bf9b3d435fbb782b41099dc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:38:45 +0200 Subject: [PATCH 159/577] Fixes in inference --- napari_cellseg3d/code_models/model_workers.py | 2 ++ napari_cellseg3d/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 3ecc47f0..7c52cd26 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -503,6 +503,8 @@ def model_output_wrapper(inputs): sw_device=self.config.device, device=dataset_device, overlap=window_overlap, + mode="gaussian", + sigma_scale=0.01, progress=True, ) except Exception as e: diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 49aea584..8ca6f146 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -204,7 +204,7 @@ def dice_coeff(y_true, y_pred): def correct_rotation(image): """Rotates the exes 0 and 2 in [DHW] section of image array""" - extra_dims = len(image) - 3 + extra_dims = len(image.shape) - 3 return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) From 2d99c79d2a69d7266c89ed96d4661da8b47b341c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 14:48:14 +0200 Subject: [PATCH 160/577] add model template + fix test + wnet loading opti - test fixes - changed crf input reqs - adapted instance seg for several channels --- napari_cellseg3d/_tests/test_models.py | 10 ++- .../_tests/test_plugin_inference.py | 11 ++-- napari_cellseg3d/_tests/test_training.py | 11 ++-- napari_cellseg3d/code_models/crf.py | 11 ++-- .../code_models/model_instance_seg.py | 29 ++++++++- napari_cellseg3d/code_models/model_workers.py | 62 +++++++++---------- .../code_models/models/TEMPLATE_model.py | 20 ++++++ .../code_models/models/model_SwinUNetR.py | 13 +++- .../code_models/models/model_WNet.py | 19 ++++++ .../code_plugins/plugin_convert.py | 2 +- 10 files changed, 129 insertions(+), 59 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/TEMPLATE_model.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 1fc15872..35af8c76 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -15,6 +15,8 @@ def test_correct_shape_for_crf(): def test_model_list(): for model_name in MODEL_LIST.keys(): + # if model_name=="test": + # continue dims = 128 test = MODEL_LIST[model_name]( input_img_size=[dims, dims, dims], @@ -39,18 +41,20 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) - # assert res > 0 + assert 0 <= res <= 1 def test_crf(qtbot): dims = 8 mock_image = np.random.rand(1, dims, dims, dims) mock_label = np.random.rand(2, dims, dims, dims) - - crf = CRFWorker(mock_image, mock_label) + assert len(mock_label.shape) == 4 + crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) + assert len(result.shape) == 4 + assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] crf.yielded.connect(on_yield) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 66c50fba..3dafeabc 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,9 +3,10 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_inference(make_napari_viewer, qtbot): @@ -28,9 +29,9 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.setCurrentIndex(-1) # widget.start() # takes too long on Github Actions # assert widget.worker is not None diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 21731ba1..921a6d26 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,9 +2,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_training(make_napari_viewer, qtbot): @@ -32,9 +33,9 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) # widget.start() # assert widget.worker is not None diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 1b8dce28..21caf35f 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -57,10 +57,10 @@ def correct_shape_for_crf(image, desired_dims=4): if len(image.shape) == desired_dims: return image if len(image.shape) > desired_dims: - if image.shape[0] > 1: - raise ValueError( - f"Image shape {image.shape} might have several channels" - ) + # if image.shape[0] > 1: + # raise ValueError( + # f"Image shape {image.shape} might have several channels" + # ) image = np.squeeze(image, axis=0) if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) @@ -200,7 +200,6 @@ def __init__( self.config = config self.log = log - # TODO(cyril) : add progress bar into log ? or do it in inference def _run_crf_job(self): """Runs the CRF post-processing step for the W-Net.""" if not CRF_INSTALLED: @@ -211,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + # labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index eb660820..73b0ba5c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from functools import partial from typing import List @@ -79,8 +80,32 @@ def __init__( ) self.counters.append(getattr(self, widget)) + @abc.abstractmethod def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") + raise NotImplementedError() + + def _make_list_from_channels( + self, image + ): # TODO(cyril) : adapt to batch dimension + if len(image.shape) > 4: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" + ) + if len(image.shape) == 4: + image = np.squeeze(image) + if len(image.shape) == 4: + return [im for im in image] + elif len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) + else: + return [image] + + def run_method_on_channels(self, image): + image_list = self._make_list_from_channels(image) # FIXME rename + result = np.array([self.run_method(im) for im in image_list]) + return result.squeeze() @dataclass @@ -582,7 +607,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method(volume) + return method.run_method_on_channels(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7c52cd26..66c7bd9a 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -644,17 +644,11 @@ def instance_seg( self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - - if len(semantic_labels.shape) == 4: - instance_labels = np.array( - [method.run_method(ch) for ch in semantic_labels] - ) - self.log(f"DEBUG instance results shape : {instance_labels.shape}") - else: - instance_labels = method.run_method(image=semantic_labels) + instance_labels = method.run_method_on_channels(semantic_labels) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") if self.config.filetype == "": - filetype = "" + filetype = ".tif" else: filetype = "_" + self.config.filetype @@ -854,7 +848,8 @@ def inference(self): weights = str( PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - model.load_state_dict( + + model.load_state_dict( # note that this is redefined in WNet_ torch.load( weights, map_location=self.config.device, @@ -1231,7 +1226,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? @@ -1261,31 +1255,32 @@ def train(self): EnsureTyped(keys=["image", "label"]), ] ) + # self.log("Loading dataset...\n") def get_loader_func(num_samples): - return Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) ), - EnsureTyped(keys=["image", "label"]), - ] - ) + ), + EnsureTyped(keys=["image", "label"]), + ] + ) if do_sampling: # if there is only one volume, split samples @@ -1309,7 +1304,6 @@ def get_loader_func(num_samples): sample_loader_train = get_loader_func(num_train_samples) sample_loader_eval = get_loader_func(num_val_samples) - logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py new file mode 100644 index 00000000..f68e5f4f --- /dev/null +++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ModelTemplate_(ABC): + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 05819e22..484890d1 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -9,12 +9,19 @@ class SwinUNETR_(SwinUNETR): use_default_training = True weights_file = "Swin64_best_metric.pth" - def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + def __init__( + self, + in_channels=1, + out_channels=1, + input_img_size=128, + use_checkpoint=True, + **kwargs, + ): try: super().__init__( input_img_size, - in_channels=1, - out_channels=1, + in_channels=in_channels, + out_channels=out_channels, feature_size=48, use_checkpoint=use_checkpoint, **kwargs, diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 4a9ff70d..86a1f7e6 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,5 +1,12 @@ +from typing import TypeVar + +from torch.nn import Module + +# local from napari_cellseg3d.code_models.models.wnet.model import WNet +T = TypeVar("T", bound="Module") + class WNet_(WNet): use_default_training = False @@ -20,6 +27,9 @@ def __init__( num_classes=num_classes, ) + def train(self: T, mode: bool = True) -> T: + raise NotImplementedError("Training not implemented for WNet") + def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. @@ -27,3 +37,12 @@ def forward(self, x): enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc + + def load_state_dict(self, state_dict, strict=False): + """Load the model state dict for inference, without the decoder weights.""" + encoder_checkpoint = state_dict.copy() + for k in state_dict.keys(): + if k.startswith("decoder"): + encoder_checkpoint.pop(k) + # print(encoder_checkpoint.keys()) + super().load_state_dict(encoder_checkpoint, strict=strict) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 44123e34..7d939c6a 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -363,7 +363,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - self.instance_widgets.run_method(imread(file)) + self.instance_widgets.run_method_on_channels(imread(file)) for file in self.images_filepaths ] utils.save_folder( From 9df47d72528580faf9df01451d20011a57b85d63 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 15:16:25 +0200 Subject: [PATCH 161/577] Update model_WNet.py --- napari_cellseg3d/code_models/models/model_WNet.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 86a1f7e6..f07ac517 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,12 +1,6 @@ -from typing import TypeVar - -from torch.nn import Module - # local from napari_cellseg3d.code_models.models.wnet.model import WNet -T = TypeVar("T", bound="Module") - class WNet_(WNet): use_default_training = False @@ -27,8 +21,8 @@ def __init__( num_classes=num_classes, ) - def train(self: T, mode: bool = True) -> T: - raise NotImplementedError("Training not implemented for WNet") + # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): """Forward ENCODER pass of the W-Net model. From eedd5847b691374a27a58941fbebb6e8a886d744 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 13 May 2023 10:29:39 +0200 Subject: [PATCH 162/577] Update model_VNet.py --- napari_cellseg3d/code_models/models/model_VNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 7aa6476e..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=2, **kwargs): + def __init__(self, in_channels=1, out_channels=1, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs From b85874b56429bedc219d996c7caf710d1d2769fb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:51:02 +0200 Subject: [PATCH 163/577] Fixed folder creation when saving to folder --- napari_cellseg3d/code_models/crf.py | 2 +- napari_cellseg3d/code_plugins/plugin_convert.py | 10 +++++----- napari_cellseg3d/code_plugins/plugin_crf.py | 2 +- napari_cellseg3d/utils.py | 3 +++ 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 21caf35f..aa9cce75 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -210,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - # labels = correct_shape_for_crf(labels) + labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 7d939c6a..77aa9af6 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -46,7 +46,7 @@ def __init__(self, viewer: "napari.Viewer.viewer", parent=None): self.aniso_widgets = ui.AnisotropyWidgets(self, always_visible=True) self.start_btn = ui.Button("Start", self._start) - self.results_path = Path.home() / Path("cellseg3d/anisotropy") + self.results_path = str(Path.home() / Path("cellseg3d/anisotropy")) self.results_filewidget.text_field.setText(str(self.results_path)) self.results_filewidget.check_ready() @@ -76,7 +76,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -175,7 +175,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -342,7 +342,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -436,7 +436,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.binarize_counter.value() if self.layer_choice: diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 7ac605e9..d8407a0f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -238,7 +238,7 @@ def _start(self): self.result_layer = self.label_layer_loader.layer() self.result_name = self.label_layer_loader.layer_name() - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) image_list = [self.image_layer_loader.layer_data()] labels_list = [self.label_layer_loader.layer_data()] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 8ca6f146..b29c5b50 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -131,6 +131,9 @@ def normalize_x(image): image = image / 127.5 - 1 return image +def mkdir_from_str(path: str, exist_ok=True, parents=True): + Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) + def normalize_y(image): """Normalizes the values of an image array to be between [0;1] rather than [0;255] From 7a9523dccbfd7982c364010b968a6ba48b15d9bd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:54:07 +0200 Subject: [PATCH 164/577] Fix check_ready for results filewidget --- napari_cellseg3d/interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 574ef23a..3001ab31 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -855,6 +855,9 @@ def __init__( self.build() self.check_ready() + if self._required: + self._text_field.textChanged.connect(self.check_ready) + def build(self): """Builds the layout of the widget""" add_widgets( @@ -913,7 +916,7 @@ def required(self, is_required): try: self.text_field.textChanged.disconnect(self.check_ready) except TypeError: - return + pass self.check_ready() self._required = is_required From e699af8b3fd7886882c9f771df0edd1988b5f610 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 11:28:33 +0200 Subject: [PATCH 165/577] Added remapping in WNet + ruff config --- .pre-commit-config.yaml | 3 + napari_cellseg3d/code_models/model_workers.py | 62 ++++++++----------- napari_cellseg3d/utils.py | 54 +++++++++++----- pyproject.toml | 36 ++++++++++- 4 files changed, 102 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61ecaae5..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,9 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + - id: check-toml # - repo: https://github.com/pycqa/isort # rev: 5.12.0 # hooks: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 66c7bd9a..4ce4d180 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -119,9 +119,9 @@ def show_progress(_, block_size, __): # count, block_size, total_size logger.info(message) return - with open(json_path) as f: + with Path.open(json_path) as f: neturls = json.load(f) - if model_name in neturls.keys(): + if model_name in neturls: url = neturls[model_name] response = urllib.request.urlopen(url) @@ -259,8 +259,7 @@ def create_inference_dict(images_filepaths): Returns: dict: list of image paths from loaded folder""" - data_dicts = [{"image": image_name} for image_name in images_filepaths] - return data_dicts + return [{"image": image_name} for image_name in images_filepaths] def set_download_log(self, widget): self.downloader.log_widget = widget @@ -304,10 +303,11 @@ def log_parameters(self): f"Thresholding is enabled at {config.post_process_config.thresholding.threshold_value}" ) - if config.sliding_window_config.is_enabled(): - status = "enabled" - else: - status = "disabled" + status = ( + "enabled" + if config.sliding_window_config.is_enabled() + else "disabled" + ) self.log(f"Window inference is {status}\n") if status == "enabled": @@ -471,10 +471,9 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) if self.config.sliding_window_config.is_enabled(): window_size = self.config.sliding_window_config.window_size @@ -491,6 +490,7 @@ def model_output( # outputs = model(inputs) def model_output_wrapper(inputs): + inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) @@ -508,7 +508,7 @@ def model_output_wrapper(inputs): progress=True, ) except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.debug("failed to run sliding window inference") self.raise_error(e, "Error during sliding window inference") logger.debug(f"Inference output shape: {outputs.shape}") @@ -519,11 +519,9 @@ def model_output_wrapper(inputs): if post_process: out = np.array(out).astype(np.float32) out = np.squeeze(out) - return out - else: - return out + return out except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Error during sliding window inference") # sys.stdout = old_stdout # sys.stderr = old_stderr @@ -634,8 +632,7 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - else: - return image + return image def instance_seg( self, semantic_labels, image_id=0, original_filename="layer" @@ -647,10 +644,11 @@ def instance_seg( instance_labels = method.run_method_on_channels(semantic_labels) self.log(f"DEBUG instance results shape : {instance_labels.shape}") - if self.config.filetype == "": - filetype = ".tif" - else: - filetype = "_" + self.config.filetype + filetype = ( + ".tif" + if self.config.filetype == "" + else "_" + self.config.filetype + ) instance_filepath = ( self.config.results_path @@ -732,9 +730,9 @@ def stats_csv(self, instance_labels): stats = [volume_stats(c) for c in instance_labels] else: stats = [volume_stats(instance_labels)] - return stats else: - return None + stats = None + return stats except ValueError as e: self.log(f"Error occurred during stats computing : {e}") return None @@ -758,10 +756,7 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - if self.config.use_crf: - crf_results = self.run_crf(image, out) - else: - crf_results = None + crf_results = self.run_crf(image, out) if self.config.use_crf else None return self.create_inference_result( semantic_labels=out, @@ -943,7 +938,7 @@ def inference(self): model.to("cpu") # self.quit() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Inference failed") self.quit() finally: @@ -1174,10 +1169,7 @@ def train(self): do_sampling = self.config.sampling - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check model = model_class( # FIXME check if correct input_img_size=utils.get_padding_dim(size), use_checkpoint=True @@ -1410,7 +1402,7 @@ def get_loader_func(num_samples): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") - logger.error(e, exc_info=True) + logger.exception(e) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index b29c5b50..7e7a5c23 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Union import napari import numpy as np @@ -9,6 +10,9 @@ from skimage.filters import gaussian from tifffile import imread, imwrite +if TYPE_CHECKING: + import torch + LOGGER = logging.getLogger(__name__) ############### # Global logging level setting @@ -128,8 +132,8 @@ def normalize_x(image): Returns: array: normalized value for the image """ - image = image / 127.5 - 1 - return image + return image / 127.5 - 1 + def mkdir_from_str(path: str, exist_ok=True, parents=True): Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) @@ -144,8 +148,7 @@ def normalize_y(image): Returns: array: normalized value for the image """ - image = image / 255 - return image + return image / 255 def sphericity_volume_area(volume, surface_area): @@ -199,10 +202,9 @@ def dice_coeff(y_true, y_pred): y_true_f = y_true.flatten() y_pred_f = y_pred.flatten() intersection = np.sum(y_true_f * y_pred_f) - score = (2.0 * intersection + smooth) / ( + return (2.0 * intersection + smooth) / ( np.sum(y_true_f) + np.sum(y_pred_f) + smooth ) - return score def correct_rotation(image): @@ -211,6 +213,27 @@ def correct_rotation(image): return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) +def normalize_max(image): + """Normalizes an image using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image.reshape(shape) + return image + + +def remap_image( + image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + image = image.reshape(shape) + return image + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, @@ -226,9 +249,8 @@ def align_array_sizes(array_shape, target_shape): for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: for j in range(len(array_shape)): - if array_shape[i] == target_shape[j]: - if j != i: - index_differences.append({"origin": i, "target": j}) + if array_shape[i] == target_shape[j] and j != i: + index_differences.append({"origin": i, "target": j}) # print(index_differences) if len(index_differences) == 0: @@ -277,10 +299,11 @@ def time_difference(time_start, time_finish, as_string=True): minutes = f"{int(minutes[0])}".zfill(2) seconds = f"{int(seconds[0])}".zfill(2) - if as_string: - return f"{hours}:{minutes}:{seconds}" - else: - return [hours, minutes, seconds] + return ( + f"{hours}:{minutes}:{seconds}" + if as_string + else [hours, minutes, seconds] + ) def get_padding_dim(image_shape, anisotropy_factor=None): @@ -446,6 +469,7 @@ def fill_list_in_between(lst, n, fill_value): for _j in range(n): new_list.append(fill_value) return new_list + return None # def check_zarr(project_path, ext): @@ -547,10 +571,8 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): "Loading as folder not implemented yet. Use napari to load as folder" ) # images_original = dask_imread(filename_pattern_original) - else: - images_original = imread(filename_pattern_original) # tifffile imread - return images_original + return imread(filename_pattern_original) # tifffile imread # def load_predicted_masks(mito_mask_dir, er_mask_dir, filetype): diff --git a/pyproject.toml b/pyproject.toml index 547047ba..74f86aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,11 +46,43 @@ where = ["."] [tool.ruff] select = [ "E", "F", "W", - "I", + "A", "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -ignore = ["E501", "E741"] +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] [tool.black] line-length = 79 From a5498f3fe24532056eb3233ff61b7a7f89c5b563 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 13:21:06 +0200 Subject: [PATCH 166/577] Run new hooks --- napari_cellseg3d/_tests/test_models.py | 13 +- .../_tests/test_weight_download.py | 2 +- napari_cellseg3d/code_models/crf.py | 25 ++-- ...stance_seg.py => instance_segmentation.py} | 19 ++- .../code_models/model_framework.py | 20 +-- .../code_models/models/model_SwinUNetR.py | 2 +- .../code_models/models/model_TRAILMAP_MS.py | 2 +- .../code_models/models/model_WNet.py | 8 +- .../code_models/models/unet/buildingblocks.py | 8 +- .../code_models/models/unet/model.py | 2 +- .../code_models/models/wnet/soft_Ncuts.py | 4 +- .../{model_workers.py => workers.py} | 2 +- .../code_plugins/plugin_convert.py | 127 +++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 11 +- .../code_plugins/plugin_metrics.py | 12 +- .../code_plugins/plugin_model_inference.py | 19 ++- .../code_plugins/plugin_model_training.py | 18 +-- .../code_plugins/plugin_review.py | 11 +- .../code_plugins/plugin_review_dock.py | 10 +- napari_cellseg3d/config.py | 8 +- .../dev_scripts/artefact_labeling.py | 16 +-- .../dev_scripts/correct_labels.py | 7 +- napari_cellseg3d/interface.py | 60 ++++----- pyproject.toml | 2 + 24 files changed, 205 insertions(+), 203 deletions(-) rename napari_cellseg3d/code_models/{model_instance_seg.py => instance_segmentation.py} (99%) rename napari_cellseg3d/code_models/{model_workers.py => workers.py} (99%) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35af8c76..35174b85 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,20 +1,23 @@ import numpy as np import torch +from numpy.random import PCG64, Generator from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +rand_gen = Generator(PCG64(12345)) + def test_correct_shape_for_crf(): - test = np.random.rand(1, 1, 8, 8, 8) + test = rand_gen.random(size=(1, 1, 8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) - test = np.random.rand(8, 8, 8) + test = rand_gen.random(size=(8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) def test_model_list(): - for model_name in MODEL_LIST.keys(): + for model_name in MODEL_LIST: # if model_name=="test": # continue dims = 128 @@ -46,8 +49,8 @@ def test_soft_ncuts_loss(): def test_crf(qtbot): dims = 8 - mock_image = np.random.rand(1, dims, dims, dims) - mock_label = np.random.rand(2, dims, dims, dims) + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 972550e9..a00ab1de 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( PRETRAINED_WEIGHTS_DIR, WeightsDownloader, ) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index aa9cce75..8c311059 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,17 +54,15 @@ def correct_shape_for_crf(image, desired_dims=4): - if len(image.shape) == desired_dims: - return image if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( # f"Image shape {image.shape} might have several channels" # ) image = np.squeeze(image, axis=0) - if len(image.shape) < desired_dims: + elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) - return correct_shape_for_crf(image) + return image def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): @@ -185,8 +183,8 @@ class CRFWorker(GeneratorWorker): def __init__( self, - images_list, - labels_list, + images_list: list, + labels_list: list, config: CRFConfig = None, log=None, ): @@ -205,16 +203,19 @@ def _run_crf_job(self): if not CRF_INSTALLED: raise ImportError("pydensecrf is not installed.") - for image, labels in zip(self.images, self.labels): - if image.shape[-3:] != labels.shape[-3:]: + if len(self.images) != len(self.labels): + raise ValueError("Number of images and labels must be the same.") + + for i in range(len(self.images)): + if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + im = correct_shape_for_crf(self.labels[i]) + prob = correct_shape_for_crf(self.labels[i]) yield crf( - image, - labels, + im, + prob, self.config.sa, self.config.sb, self.config.sg, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/instance_segmentation.py similarity index 99% rename from napari_cellseg3d/code_models/model_instance_seg.py rename to napari_cellseg3d/code_models/instance_segmentation.py index 73b0ba5c..59909c6e 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -91,16 +91,16 @@ def _make_list_from_channels( raise ValueError( f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" ) + if len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) if len(image.shape) == 4: image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - elif len(image.shape) < 2: - raise ValueError( - f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" - ) - else: return [image] + return None def run_method_on_channels(self, image): image_list = self._make_list_from_channels(image) # FIXME rename @@ -313,12 +313,10 @@ def to_instance(image, is_file_path=False): image = [imread(image)] # image = image.compute() - result = binary_watershed( + return binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 ) # FIXME add params from utils plugin - return result - def to_semantic(image, is_file_path=False): """Converts a **ground-truth** label to semantic (binary 0/1) labels. @@ -335,8 +333,7 @@ def to_semantic(image, is_file_path=False): # image = image.compute() image[image >= 1] = 1 - result = image.astype(np.uint16) - return result + return image.astype(np.uint16) def volume_stats(volume_image): @@ -588,7 +585,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 1c3abe3f..60644916 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,8 +1,11 @@ from pathlib import Path +from typing import TYPE_CHECKING -import napari import torch +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QProgressBar, QSizePolicy @@ -126,7 +129,7 @@ def save_log(self): path = self.results_path if len(log) != 0: - with open( + with Path.open( path + f"/Log_report_{utils.get_date_time()}.txt", "x", ) as f: @@ -152,8 +155,8 @@ def save_log_to_path(self, path): ) if len(log) != 0: - with open( - path, + with Path.open( + Path(path), "x", ) as f: f.write(log) @@ -282,11 +285,10 @@ def _load_weights_path(self): ) if file[0] == self._default_weights_folder: return - if file is not None: - if file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 484890d1..2d7b5ef6 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -27,7 +27,7 @@ def __init__( **kwargs, ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( input_img_size, in_channels=1, diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 1123173a..baf8635d 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -16,7 +16,7 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs): in_channels=in_channels, out_channels=out_channels, **kwargs ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( in_channels=in_channels, out_channels=out_channels ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index f07ac517..7235bd61 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -28,14 +28,14 @@ def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. """ - enc = self.forward_encoder(x) - # dec = self.forward_decoder(enc) - return enc + return self.forward_encoder(x) + # enc = self.forward_encoder(x) + # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" encoder_checkpoint = state_dict.copy() - for k in state_dict.keys(): + for k in state_dict: if k.startswith("decoder"): encoder_checkpoint.pop(k) # print(encoder_checkpoint.keys()) diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 4cdc0a43..ce7d378f 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -64,10 +64,7 @@ def create_conv( ) elif char == "g": is_before_conv = i < order.index("c") - if is_before_conv: - num_channels = in_channels - else: - num_channels = out_channels + num_channels = in_channels if is_before_conv else out_channels # use only one group if the given number of groups is greater than the number of channels if num_channels < num_groups: @@ -425,8 +422,7 @@ def forward(self, encoder_features, x): def _joining(encoder_features, x, concat): if concat: return torch.cat((encoder_features, x), dim=1) - else: - return encoder_features + x + return encoder_features + x def create_encoders( diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index 9614a555..9591d054 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -64,7 +64,7 @@ def __init__( f_maps, num_levels=num_levels ) - assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert isinstance(f_maps, (list, tuple)) assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" # create encoder path diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 4e84579f..938292c2 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -206,6 +206,7 @@ def forward(self, labels, inputs): return torch.add(torch.neg(loss), K) """ + return None def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -348,5 +349,4 @@ def get_weights(self, inputs): 1, 1, self.W_X.shape[0], self.W_X.shape[1] ) # (1, 1, H*W*D, H*W*D) - W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) - return W + return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/workers.py similarity index 99% rename from napari_cellseg3d/code_models/model_workers.py rename to napari_cellseg3d/code_models/workers.py index 4ce4d180..c1ed62fd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -54,7 +54,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( ImageStats, volume_stats, ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 77aa9af6..4357e51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -7,7 +7,7 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, clear_small_objects, threshold, @@ -98,18 +98,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -193,18 +194,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -274,18 +276,19 @@ def _start(self): utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -360,18 +363,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method_on_channels(imread(file)) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method_on_channels(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -454,18 +458,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index e2189a15..a27b4baa 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -157,8 +157,10 @@ def _build(self): dim_group_l.addWidget(self.aniso_widgets) [ dim_group_l.addWidget(widget, alignment=ui.ABS_AL) - for list in zip(self.crop_size_labels, self.crop_size_widgets) - for widget in list + for widget_list in zip( + self.crop_size_labels, self.crop_size_widgets + ) + for widget in widget_list ] dim_group_w.setLayout(dim_group_l) layout.addWidget(dim_group_w) @@ -237,10 +239,7 @@ def quicksave(self): def _check_ready(self): if self.image_layer_loader.layer_data() is not None: if self.crop_second_image: - if self.label_layer_loader.layer_data() is not None: - return True - else: - return False + return self.label_layer_loader.layer_data() is not None return True return False diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 114025f6..2a6e713c 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt -import napari import numpy as np from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, @@ -8,9 +9,12 @@ from monai.transforms import SpatialPad, ToTensor from tifffile import imread +if TYPE_CHECKING: + import napari + from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.instance_segmentation import to_semantic from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder DEFAULT_THRESHOLD = 0.5 @@ -187,11 +191,11 @@ def compute_dice(self): self.canvas = ( None # kind of terrible way to stack plots... but it works. ) - id = 0 + image_id = 0 for ground_path, pred_path in zip( self.images_filepaths, self.labels_filepaths ): - id += 1 + image_id += 1 ground = imread(ground_path) pred = imread(pred_path) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 78cf4438..043c5947 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,18 +1,21 @@ from functools import partial +from typing import TYPE_CHECKING -import napari import numpy as np import pandas as pd +if TYPE_CHECKING: + import napari + # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.workers import ( InferenceResult, InferenceWorker, ) @@ -289,9 +292,11 @@ def check_ready(self): if self.layer_choice.isChecked(): if self.image_layer_loader.layer_data() is not None: return True - elif self.folder_choice.isChecked(): - if self.image_filewidget.check_ready(): - return True + elif ( + self.folder_choice.isChecked() + and self.image_filewidget.check_ready() + ): + return True return False def _toggle_display_model_input_size(self): diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 86d1d317..35a16799 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,9 +1,9 @@ import shutil from functools import partial from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt -import napari import numpy as np import pandas as pd import torch @@ -12,6 +12,9 @@ ) from matplotlib.figure import Figure +if TYPE_CHECKING: + import napari + # MONAI from monai.losses import ( DiceCELoss, @@ -29,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( TrainingReport, TrainingWorker, ) @@ -414,11 +417,10 @@ def check_ready(self): * False and displays a warning if not """ - if self.images_filepaths != [] and self.labels_filepaths != []: - return True - else: + if self.images_filepaths == [] and self.labels_filepaths != []: logger.warning("Image and label paths are not correctly set") return False + return True def _build(self): """Builds the layout of the widget and creates the following tabs and prompts: @@ -999,7 +1001,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1131,7 +1133,7 @@ def update_loss_plot(self, loss, metric): epoch = len(loss) if epoch < self.worker_config.validation_interval * 2: return - elif epoch == self.worker_config.validation_interval * 2: + if epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): self.canvas = FigureCanvas(Figure(figsize=(10, 1.5))) @@ -1167,7 +1169,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 235595e4..dd98bcd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -178,11 +178,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - logger.warning( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -400,7 +399,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 6cee7c94..f634d117 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,9 +1,12 @@ from datetime import datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING -import napari import pandas as pd +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QVBoxLayout, QWidget @@ -213,10 +216,7 @@ def create_csv(self, label_dir, model_type, filename=None): ) else: # print(self.image_dims[0]) - if self.filename is not None: - filename = self.filename - else: - filename = "image" + filename = self.filename if self.filename is not None else "image" labels = [str(filename) for i in range(self.image_dims[0])] df = pd.DataFrame( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 7250fe78..af42d779 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,7 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ @@ -89,9 +89,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :") + for model_name in MODEL_LIST: + logger.info(f" * {model_name}") return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b4712aec..93746eb6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ -import os +import os # TODO(cyril): remove os +from pathlib import Path import napari import numpy as np @@ -6,7 +7,7 @@ from skimage.filters import threshold_otsu from tifffile import imread, imwrite -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -289,18 +290,13 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. """ - if not is_labeled: - # find all the connected components in the artefacts image - labels = ndimage.label(artefacts)[0] - else: - labels = artefacts + labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - artefacts = np.where(np.isin(labels, labels_i), labels, 0) - return artefacts + return np.where(np.isin(labels, labels_i), labels, 0) def create_artefact_labels( @@ -388,7 +384,7 @@ def create_artefact_labels_from_folder( path_labels.sort() path_images.sort() # create the output folder - os.makedirs(path + "/artefact_neurons", exist_ok=True) + Path().mkdir(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9862c3fa..4a7363b2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -12,7 +12,7 @@ from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) @@ -228,10 +228,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 3001ab31..061f4d1d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,3 +1,4 @@ +import contextlib import threading from functools import partial from typing import List, Optional @@ -105,12 +106,12 @@ def __call__(cls, *args, **kwargs): ################## -def handle_adjust_errors(widget, type, context, msg: str): +def handle_adjust_errors(widget, warning_type, context, msg: str): """Qt message handler that attempts to react to errors when setting the window size and resizes the main window""" pass # head = msg.split(": ")[0] - # if type == QtWarningMsg and head == "QWindowsWindow::setGeometry": + # if warning_type == QtWarningMsg and head == "QWindowsWindow::setGeometry": # logger.warning( # f"Qt resize error : {msg}\nhas been handled by attempting to resize the window" # ) @@ -191,7 +192,7 @@ def show_utils_menu(self, widget, event): menu.setStyleSheet(f"background-color: {napari_grey}; color: white;") actions = [] - for title in UTILITIES_WIDGETS.keys(): + for title in UTILITIES_WIDGETS: a = menu.addAction(f"Utilities : {title}") actions.append(a) @@ -333,8 +334,7 @@ def toggle_visibility(checkbox, widget): def add_label(widget, label, label_before=True, horizontal=True): if label_before: return combine_blocks(widget, label, horizontal=horizontal) - else: - return combine_blocks(label, widget, horizontal=horizontal) + return combine_blocks(label, widget, horizontal=horizontal) class ContainerWidget(QWidget): @@ -736,8 +736,7 @@ def anisotropy_zoom_factor(aniso_res): """ base = min(aniso_res) - zoom_factors = [base / res for res in aniso_res] - return zoom_factors + return [base / res for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" @@ -797,8 +796,8 @@ def _remove_layer(self, event): index = self.layer_list.findText(removed_layer.name) self.layer_list.removeItem(index) - def set_layer_type(self, type): # no @property due to Qt constraint - self.layer_type = type + def set_layer_type(self, layer_type): # no @property due to Qt constraint + self.layer_type = layer_type [self.layer_list.removeItem(i) for i in range(self.layer_list.count())] self._check_for_layers() @@ -811,7 +810,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: logger.warning("Please select a valid layer !") - return + return None return self.layer().data @@ -899,9 +898,8 @@ def check_ready(self): self.update_field_color("indianred") self.text_field.setToolTip("Mandatory field !") return False - else: - self.update_field_color(f"{napari_param_darkgrey}") - return True + self.update_field_color(f"{napari_param_darkgrey}") + return True @property def required(self): @@ -913,10 +911,9 @@ def required(self, is_required): if is_required: self.text_field.textChanged.connect(self.check_ready) else: - try: + with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - except TypeError: - pass + self.check_ready() self._required = is_required @@ -1003,22 +1000,22 @@ def make_scrollable( def set_spinbox( box, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, fixed: Optional[bool] = True, ): """Args: box : QSpinBox or QDoubleSpinBox - min : minimum value, defaults to 0 - max : maximum value, defaults to 10 + min_value : minimum value, defaults to 0 + max_value : maximum value, defaults to 10 default : default value, defaults to 0 step : step value, defaults to 1 fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" - box.setMinimum(min) - box.setMaximum(max) + box.setMinimum(min_value) + box.setMaximum(max_value) box.setSingleStep(step) box.setValue(default) @@ -1029,8 +1026,8 @@ def set_spinbox( def make_n_spinboxes( class_, n: int = 2, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, parent: Optional[QWidget] = None, @@ -1041,8 +1038,8 @@ def make_n_spinboxes( Args: class_ : QSpinBox or QDoubleSpinbox n (int): number of increment counters to create - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 + min_value (Optional[int]): minimum value, defaults to 0 + max_value (Optional[int]): maximum value, defaults to 10 default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None @@ -1053,7 +1050,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min, max, default, step, parent, fixed) + box = class_(min_value, max_value, default, step, parent, fixed) boxes.append(box) return boxes @@ -1226,10 +1223,9 @@ def open_file_dialog( default_path = utils.parse_default_path(possible_paths) - f_name = QFileDialog.getOpenFileName( + return QFileDialog.getOpenFileName( widget, "Choose file", default_path, filetype ) - return f_name def open_folder_dialog( @@ -1239,10 +1235,9 @@ def open_folder_dialog( default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - filenames = QFileDialog.getExistingDirectory( + return QFileDialog.getExistingDirectory( widget, "Open directory", default_path + "/.." ) - return filenames def make_label(name, parent=None): # TODO update to child class @@ -1259,12 +1254,11 @@ def make_label(name, parent=None): # TODO update to child class label = QLabel(name, parent) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label else: label = QLabel(name) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label + return label def make_group(title, l=7, t=20, r=7, b=11, parent=None): diff --git a/pyproject.toml b/pyproject.toml index 74f86aa2..082176b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ exclude = [ "dist", "node_modules", "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", ] [tool.black] From 18083f75e17881e46d1abce78fe72a0aff027000 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:06:24 +0200 Subject: [PATCH 167/577] Small docs update --- docs/index.rst | 4 +- docs/res/code/instance_segmentation.rst | 53 +++++++++++++++++++ docs/res/code/model_instance_seg.rst | 53 ------------------- docs/res/code/plugin_convert.rst | 15 ------ docs/res/code/utils.rst | 4 -- .../code/{model_workers.rst => workers.rst} | 8 +-- docs/res/guides/custom_model_template.rst | 28 +++++++++- docs/res/guides/detailed_walkthrough.rst | 4 +- docs/res/guides/inference_module_guide.rst | 2 +- docs/res/guides/training_module_guide.rst | 2 +- napari_cellseg3d/code_models/workers.py | 28 +++++----- 11 files changed, 105 insertions(+), 96 deletions(-) create mode 100644 docs/res/code/instance_segmentation.rst delete mode 100644 docs/res/code/model_instance_seg.rst rename docs/res/code/{model_workers.rst => workers.rst} (78%) diff --git a/docs/index.rst b/docs/index.rst index 7e809fbe..46c57c08 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,8 +39,8 @@ Welcome to napari-cellseg3d's documentation! res/code/plugin_convert res/code/plugin_metrics res/code/model_framework - res/code/model_workers - res/code/model_instance_seg + res/code/workers + res/code/instance_segmentation res/code/plugin_model_inference res/code/plugin_model_training res/code/utils diff --git a/docs/res/code/instance_segmentation.rst b/docs/res/code/instance_segmentation.rst new file mode 100644 index 00000000..143560c4 --- /dev/null +++ b/docs/res/code/instance_segmentation.rst @@ -0,0 +1,53 @@ +instance_segmentation.py +=========================================== + +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::VoronoiOtsu + :members: __init__ + + +Functions +------------- + +binary_connected +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_connected + +binary_watershed +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_watershed + +volume_stats +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::volume_stats + +clear_small_objects +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::clear_small_objects + +to_instance +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_instance + +to_semantic +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_semantic diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst deleted file mode 100644 index 3b323173..00000000 --- a/docs/res/code/model_instance_seg.rst +++ /dev/null @@ -1,53 +0,0 @@ -model_instance_seg.py -=========================================== - -Classes -------------- - -InstanceMethod -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod - :members: __init__ - -ConnectedComponents -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents - :members: __init__ - -Watershed -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed - :members: __init__ - -VoronoiOtsu -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu - :members: __init__ - - -Functions -------------- - -binary_connected -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_connected - -binary_watershed -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_watershed - -volume_stats -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::volume_stats - -clear_small_objects -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::clear_small_objects - -to_instance -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_instance - -to_semantic -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_semantic diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 03944510..25006d0f 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,18 +28,3 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ - -Functions ------------------------------------ - -save_folder -***************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder - -save_layer -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer - -show_result -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index e90ee7e0..d9fdcfa2 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,7 +62,3 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images - -format_Warning -************************************** -.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/docs/res/code/model_workers.rst b/docs/res/code/workers.rst similarity index 78% rename from docs/res/code/model_workers.rst rename to docs/res/code/workers.rst index 85f8da29..1f5167ad 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/workers.rst @@ -1,4 +1,4 @@ -model_workers.py +workers.py =========================================== @@ -10,7 +10,7 @@ Class : LogSignal Attributes ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::LogSignal +.. autoclass:: napari_cellseg3d.code_models.workers::LogSignal :members: log_signal :noindex: @@ -24,7 +24,7 @@ Class : InferenceWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::InferenceWorker +.. autoclass:: napari_cellseg3d.code_models.workers::InferenceWorker :members: __init__, log, create_inference_dict, inference :noindex: @@ -39,6 +39,6 @@ Class : TrainingWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker +.. autoclass:: napari_cellseg3d.code_models.workers::TrainingWorker :members: __init__, log, train :noindex: diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 218795b1..a70df29b 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -3,9 +3,33 @@ Advanced : Declaring a custom model ============================================= -To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder: +.. warning:: + **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. + + Please `file an issue`_ if you would like to add a custom model and we will help you get it working. + +To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: + + class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) # download URL goes in pretrained_models.json + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass + .. note:: **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute -:: +.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues diff --git a/docs/res/guides/detailed_walkthrough.rst b/docs/res/guides/detailed_walkthrough.rst index 407893c2..3d06d998 100644 --- a/docs/res/guides/detailed_walkthrough.rst +++ b/docs/res/guides/detailed_walkthrough.rst @@ -1,6 +1,6 @@ .. _detailed_walkthrough: -Detailed walkthrough +Detailed walkthrough - Supervised learning =================================== The following guide will show you how to use the plugin's workflow, starting from human-labeled annotation volume, to running inference on novel volumes. @@ -109,7 +109,7 @@ of two no matter the size you choose. For optimal performance, make sure to use a power of two still, such as 64 or 120. .. important:: - Using a too large value for the size will cause memory issues. If this happens, restart napari (better handling for these situations might be added in the future). + Using a too large value for the size will cause memory issues. If this happens, restart the worker with smaller volumes. You also have the option to use data augmentation, which can improve performance and generalization. In most cases this should left enabled. diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst index 00e67078..373e9d0d 100644 --- a/docs/res/guides/inference_module_guide.rst +++ b/docs/res/guides/inference_module_guide.rst @@ -132,4 +132,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_inference` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index 05ce69be..1038dc6d 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -128,4 +128,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_training` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c1ed62fd..e2e21363 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -61,16 +61,6 @@ logger = utils.LOGGER -""" -Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) -so instead, following the instructions in the guides below to have a worker with custom signals, I implemented -a custom worker function.""" - -# FutureReference(): -# https://python-forum.io/thread-31349.html -# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ -# https://napari-staging-site.github.io/guides/stable/threading.html - PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) @@ -174,12 +164,23 @@ def safe_extract( ) +""" +Writing something to log messages from outside the main thread needs specific care, +Following the instructions in the guides below to have a worker with custom signals, +a custom worker function was implemented. +""" + +# https://python-forum.io/thread-31349.html +# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ +# https://napari-staging-site.github.io/guides/stable/threading.html + + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_ + Separate from Worker instances as indicated `on this post`_ - .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect """ # TODO link ? log_signal = Signal(str) @@ -196,6 +197,9 @@ def __init__(self): super().__init__() +# TODO(cyril): move inference and training workers to separate files + + @dataclass class InferenceResult: """Class to record results of a segmentation job""" From b1a7f2c32d03b362789f32780858ad6e4f7544e5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:24:43 +0200 Subject: [PATCH 168/577] Testing fix --- napari_cellseg3d/code_models/instance_segmentation.py | 5 ++--- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 59909c6e..d1506f3d 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -99,11 +99,10 @@ def _make_list_from_channels( image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - return [image] - return None + return [image] def run_method_on_channels(self, image): - image_list = self._make_list_from_channels(image) # FIXME rename + image_list = self._make_list_from_channels(image) result = np.array([self.run_method(im) for im in image_list]) return result.squeeze() diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 7235bd61..cb5ef6d8 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,7 @@ def __init__( num_classes=num_classes, ) - # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): From 1126c79d9637a9e6695b5cfe92ed50e138efab68 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:59:05 +0200 Subject: [PATCH 169/577] Fixed multithread testing (locally) --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 14 +- .../_tests/test_plugin_inference.py | 29 ++-- napari_cellseg3d/_tests/test_training.py | 27 ++-- .../code_plugins/plugin_model_inference.py | 127 +++++++++--------- .../code_plugins/plugin_model_training.py | 108 ++++++++------- 6 files changed, 158 insertions(+), 148 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 88a67ae2..fa6905d5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -9,6 +9,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35174b85..4852f651 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -52,7 +52,7 @@ def test_crf(qtbot): mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 - crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) + crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) @@ -60,20 +60,20 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=True ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() mock_image = mock_image[0] mock_label = mock_label[0] crf = CRFWorker(mock_image, mock_label) - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=False ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 3dafeabc..d1264218 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,10 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_inference(make_napari_viewer, qtbot): @@ -29,14 +28,16 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.setCurrentIndex(-1) - - # widget.start() # takes too long on Github Actions - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: - # blocker.connect(widget.worker.errored) - - #### assert len(viewer.layers) == 2 + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + widget.worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(widget.config) + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() # takes too long on Github Actions + assert widget.worker is not None + # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 921a6d26..4d558363 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,10 +2,9 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_training(make_napari_viewer, qtbot): @@ -33,15 +32,19 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - - # widget.start() - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. - # blocker.connect(widget.worker.errored) + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + + worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(worker_config) + + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() + assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 043c5947..3dd843d4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -551,66 +551,7 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.model_info = config.ModelInfo( - name=self.model_choice.currentText(), - model_input_size=self.model_input_size.value(), - ) - - self.weights_config.custom = self.custom_weights_choice.isChecked() - - save_path = self.results_filewidget.text_field.text() - if not self._check_results_path(save_path): - msg = f"ERROR: please set valid results path. Current path is {save_path}" - self.log.print_and_log(msg) - logger.warning(msg) - else: - if self.results_path is None: - self.results_path = save_path - - zoom_config = config.Zoom( - enabled=self.anisotropy_wdgt.enabled(), - zoom_values=self.anisotropy_wdgt.scaling_xyz(), - ) - thresholding_config = config.Thresholding( - enabled=self.thresholding_checkbox.isChecked(), - threshold_value=self.thresholding_slider.slider_value, - ) - - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) - - self.post_process_config = config.PostProcessConfig( - zoom=zoom_config, - thresholding=thresholding_config, - instance=self.instance_config, - ) - - if self.window_infer_box.isChecked(): - size = int(self.window_size_choice.currentText()) - window_config = config.SlidingWindowConfig( - window_size=size, - window_overlap=self.window_overlap_slider.slider_value, - ) - else: - window_config = config.SlidingWindowConfig() - - self.worker_config = config.InferenceWorkerConfig( - device=self.get_device(), - model_info=self.model_info, - weights_config=self.weights_config, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), - compute_stats=self.save_stats_to_csv_box.isChecked(), - post_process_config=self.post_process_config, - sliding_window_config=window_config, - use_crf=self.use_crf.isChecked(), - crf_config=self.crf_widgets.make_config(), - ) + self._set_worker_config() ##################### ##################### ##################### @@ -652,6 +593,72 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.InferenceWorkerConfig): + return InferenceWorker(worker_config=config) + + def _set_worker_config(self) -> config.InferenceWorkerConfig: + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) + + self.weights_config.custom = self.custom_weights_choice.isChecked() + + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + logger.warning(msg) + else: + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], + ) + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, + ) + else: + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), + ) + return self.worker_config + def on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`""" self.display_status_report() diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 35a16799..e11eb3de 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -808,64 +808,10 @@ def start(self): self.data = None raise err - model_config = config.ModelInfo( - name=self.model_choice.currentText() - ) - - self.weights_config.path = self.weights_config.path - self.weights_config.custom = self.custom_weights_choice.isChecked() - self.weights_config.use_pretrained = ( - not self.use_transfer_choice.isChecked() - ) - - deterministic_config = config.DeterministicConfig( - enabled=self.use_deterministic_choice.isChecked(), - seed=self.box_seed.value(), - ) - - validation_percent = ( - self.validation_percent_choice.slider_value / 100 - ) - - results_path_folder = Path( - self.results_path - + f"/{model_config.name}_{utils.get_date_time()}" - ) - Path(results_path_folder).mkdir( - parents=True, exist_ok=False - ) # avoid overwrite where possible - - patch_size = [w.value() for w in self.patch_size_widgets] - - logger.debug("Loading config...") - self.worker_config = config.TrainingWorkerConfig( - device=self.get_device(), - model_info=model_config, - weights_info=self.weights_config, - train_data_dict=self.data, - validation_percent=validation_percent, - max_epochs=self.epoch_choice.value(), - loss_function=self.get_loss(self.loss_choice.currentText()), - learning_rate=float(self.learning_rate_choice.currentText()), - scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.slider_value, - validation_interval=self.val_interval_choice.value(), - batch_size=self.batch_choice.slider_value, - results_path_folder=str(results_path_folder), - sampling=self.patch_choice.isChecked(), - num_samples=self.sample_choice_slider.slider_value, - sample_size=patch_size, - do_augmentation=self.augment_choice.isChecked(), - deterministic_config=deterministic_config, - ) # TODO(cyril) continue to put params in config - self.config = config.TrainerConfig( save_as_zip=self.zip_choice.isChecked() ) - - self.log.print_and_log( - f"Saving results to : {results_path_folder}" - ) + self._set_worker_config() self.worker = TrainingWorker(config=self.worker_config) self.worker.set_download_log(self.log) @@ -895,6 +841,58 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.TrainingWorkerConfig): + return TrainingWorker(config=config) + + def _set_worker_config(self) -> config.TrainingWorkerConfig: + model_config = config.ModelInfo(name=self.model_choice.currentText()) + + self.weights_config.path = self.weights_config.path + self.weights_config.custom = self.custom_weights_choice.isChecked() + self.weights_config.use_pretrained = ( + not self.use_transfer_choice.isChecked() + ) + + deterministic_config = config.DeterministicConfig( + enabled=self.use_deterministic_choice.isChecked(), + seed=self.box_seed.value(), + ) + + validation_percent = self.validation_percent_choice.slider_value / 100 + + results_path_folder = Path( + self.results_path + f"/{model_config.name}_{utils.get_date_time()}" + ) + Path(results_path_folder).mkdir( + parents=True, exist_ok=False + ) # avoid overwrite where possible + + patch_size = [w.value() for w in self.patch_size_widgets] + + logger.debug("Loading config...") + self.worker_config = config.TrainingWorkerConfig( + device=self.get_device(), + model_info=model_config, + weights_info=self.weights_config, + train_data_dict=self.data, + validation_percent=validation_percent, + max_epochs=self.epoch_choice.value(), + loss_function=self.get_loss(self.loss_choice.currentText()), + learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, + validation_interval=self.val_interval_choice.value(), + batch_size=self.batch_choice.slider_value, + results_path_folder=str(results_path_folder), + sampling=self.patch_choice.isChecked(), + num_samples=self.sample_choice_slider.slider_value, + sample_size=patch_size, + do_augmentation=self.augment_choice.isChecked(), + deterministic_config=deterministic_config, + ) # TODO(cyril) continue to put params in config + + return self.worker_config + def on_start(self): """Catches started signal from worker""" From d4d877b7a8db25eb2c09fe0b8b3bfd812760608c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:06:02 +0200 Subject: [PATCH 170/577] Added proper tests for train/infer --- .../_tests/test_plugin_inference.py | 36 ++++++++++++++----- napari_cellseg3d/_tests/test_training.py | 34 ++++++++++++------ napari_cellseg3d/code_models/workers.py | 4 +-- .../code_plugins/plugin_model_inference.py | 8 +++-- .../code_plugins/plugin_model_training.py | 10 ++++-- 5 files changed, 67 insertions(+), 25 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index d1264218..04305082 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -4,7 +4,10 @@ from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_inference import ( + InferenceResult, + Inferer, +) from napari_cellseg3d.config import MODEL_LIST @@ -28,16 +31,31 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(widget.config) - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() # takes too long on Github Actions - assert widget.worker is not None + assert widget.worker_config is not None + assert widget.model_info is not None + worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None + assert worker.config.model_info is not None + worker.config.layer = viewer.layers[0].data + assert worker.config.layer is not None + worker.log_parameters() + + res = next(worker.inference()) + assert isinstance(res, InferenceResult) + assert res.result.shape == (6, 6, 6) + + # def on_error(e): + # print(e) + # assert False + # with qtbot.waitSignal( + # signal=worker.finished, timeout=10000, raising=True + # ) as blocker: + # worker.error_signal.connect(on_error) + # blocker.connect(worker.errored) + # worker.inference() # takes too long on Github Actions # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 4d558363..080df419 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,7 +3,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_model_training import ( + Trainer, + TrainingReport, +) from napari_cellseg3d.config import MODEL_LIST @@ -32,19 +35,30 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(worker_config) - - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() - assert widget.worker is not None + worker = widget._create_worker_from_config(worker_config) + worker.config.train_data_dict = [{"image": im_path, "label": im_path}] + worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.log_parameters() + res = next(worker.train()) + + assert isinstance(res, TrainingReport) + + # def on_error(e): + # print(e) + # assert False + # + # with qtbot.waitSignal( + # signal=widget.worker.finished, timeout=10000, raising=True + # ) as blocker: + # blocker.connect(widget.worker.errored) + # widget.worker.error_signal.connect(on_error) + # widget.worker.train() + # assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index e2e21363..6dd32c80 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -965,7 +965,7 @@ class TrainingWorker(GeneratorWorker): def __init__( self, - config: config.TrainingWorkerConfig, + worker_config: config.TrainingWorkerConfig, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -1012,7 +1012,7 @@ def __init__( self._weight_error = False ############################################# - self.config = config + self.config = worker_config self.train_files = [] self.val_files = [] diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 3dd843d4..b619ac92 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -593,8 +593,12 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.InferenceWorkerConfig): - return InferenceWorker(worker_config=config) + def _create_worker_from_config( + self, worker_config: config.InferenceWorkerConfig + ): + if isinstance(worker_config, config.InfererConfig): + raise TypeError("Please provide a valid worker config object") + return InferenceWorker(worker_config=worker_config) def _set_worker_config(self) -> config.InferenceWorkerConfig: self.model_info = config.ModelInfo( diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index e11eb3de..2a131a5f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -841,8 +841,14 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.TrainingWorkerConfig): - return TrainingWorker(config=config) + def _create_worker_from_config( + self, worker_config: config.TrainingWorkerConfig + ): + if isinstance(config, config.TrainerConfig): + raise TypeError( + "Expected a TrainingWorkerConfig, got a TrainerConfig" + ) + return TrainingWorker(worker_config=worker_config) def _set_worker_config(self) -> config.TrainingWorkerConfig: model_config = config.ModelInfo(name=self.model_choice.currentText()) From f0645158bef8494bd627ad063557defe90ac8702 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:31:36 +0200 Subject: [PATCH 171/577] Slight coverage increase --- napari_cellseg3d/_tests/test_plugin_inference.py | 13 ++----------- napari_cellseg3d/_tests/test_training.py | 1 + napari_cellseg3d/code_models/models/model_test.py | 2 +- napari_cellseg3d/code_models/workers.py | 6 +++--- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 04305082..c437ac83 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -39,23 +39,14 @@ def test_inference(make_napari_viewer, qtbot): assert widget.worker_config is not None assert widget.model_info is not None worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data + worker.config.post_process_config.instance.enabled = True assert worker.config.layer is not None worker.log_parameters() res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) - - # def on_error(e): - # print(e) - # assert False - # with qtbot.waitSignal( - # signal=worker.finished, timeout=10000, raising=True - # ) as blocker: - # worker.error_signal.connect(on_error) - # blocker.connect(worker.errored) - # worker.inference() # takes too long on Github Actions - # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 080df419..e7f1e07b 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -43,6 +43,7 @@ def test_training(make_napari_viewer, qtbot): worker = widget._create_worker_from_config(worker_config) worker.config.train_data_dict = [{"image": im_path, "label": im_path}] worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.config.max_epochs = 1 worker.log_parameters() res = next(worker.train()) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1ccac3da..1cb52f06 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -8,7 +8,7 @@ class TestModel(nn.Module): def __init__(self, **kwargs): super().__init__() - self.linear = nn.Linear(1, 1) + self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 6dd32c80..8ddc7921 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1425,9 +1425,9 @@ def get_loader_func(num_samples): device = self.config.device - if model_name == "test": - self.quit() - yield TrainingReport(False) + # if model_name == "test": + # self.quit() + # yield TrainingReport(False) for epoch in range(self.config.max_epochs): # self.log("\n") From 1ecbf6ca2fc5777fcaf1dd22aa3c5e7aec941baf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:45:47 +0200 Subject: [PATCH 172/577] Update test_plugin_inference.py --- napari_cellseg3d/_tests/test_plugin_inference.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index c437ac83..ca8e84d4 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,6 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.instance_segmentation import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import ( InferenceResult, @@ -44,6 +47,10 @@ def test_inference(make_napari_viewer, qtbot): assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data worker.config.post_process_config.instance.enabled = True + worker.config.post_process_config.instance.method = ( + INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() + ) + assert worker.config.layer is not None worker.log_parameters() From 831758626be4907f5440c39fd684375f10aeff61 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 11:41:39 +0200 Subject: [PATCH 173/577] Set window inference to 64 for WNet --- .../code_plugins/plugin_model_inference.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index b619ac92..c49b7d23 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -119,6 +119,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size ) + self.model_choice.currentIndexChanged.connect( + self._restrict_window_size_for_model + ) self.model_choice.setCurrentIndex(0) self.anisotropy_wdgt = ui.AnisotropyWidgets( @@ -150,9 +153,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.window_infer_box = ui.CheckBox("Use window inference") - self.window_infer_box.clicked.connect(self._toggle_display_window_size) + self.window_infer_box.toggled.connect(self._toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + self._default_window_size = sizes_window.index("64") # ( # self.window_size_choice, # self.window_size_choice.label, @@ -167,7 +171,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(3) # set to 64 by default + self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -192,7 +196,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) - self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## @@ -299,6 +302,19 @@ def check_ready(self): return True return False + def _restrict_window_size_for_model(self): + """Sets the window size to a value that is compatible with the chosen model""" + if self.model_choice.currentText() == "WNet": + self.window_size_choice.setCurrentIndex(self._default_window_size) + self.window_size_choice.setDisabled(True) + self.window_infer_box.setChecked(True) + self.window_infer_box.setDisabled(True) + else: + self.window_size_choice.setDisabled(False) + self.window_infer_box.setDisabled(False) + self.window_infer_box.setChecked(False) + self.window_size_choice.setCurrentIndex(self._default_window_size) + def _toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" From 2dea03970455624a1733b0cfd06b7a3623bb88b8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 22:00:16 +0200 Subject: [PATCH 174/577] Update instance_segmentation.py --- napari_cellseg3d/code_models/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index d1506f3d..f5066ebe 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -394,7 +394,7 @@ def sphericity(region): return ImageStats( volume, [region.centroid[0] for region in properties], - [region.centroid[0] for region in properties], + [region.centroid[1] for region in properties], [region.centroid[2] for region in properties], sphericity_ax, fill([volume_image.shape]), From 69b41d32cc9b0d821f11cff7549e18d821e2ac28 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 20 May 2023 09:22:52 +0200 Subject: [PATCH 175/577] Moved normalization to the correct place --- napari_cellseg3d/code_models/workers.py | 2 +- napari_cellseg3d/utils.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8ddc7921..dd9e38e3 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -492,9 +492,9 @@ def model_output( logger.debug(f"inputs type : {inputs.dtype}") try: # outputs = model(inputs) + inputs = utils.remap_image(inputs) def model_output_wrapper(inputs): - inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7e7a5c23..b6857a19 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -223,12 +223,18 @@ def normalize_max(image): def remap_image( - image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 + image: Union["np.ndarray", "torch.Tensor"], + new_max=100, + new_min=0, + prev_max=None, + prev_min=None, ): """Normalizes a numpy array or Tensor using the max and min value""" shape = image.shape image = image.flatten() - image = (image - image.min()) / (image.max() - image.min()) + im_max = prev_max if prev_max is not None else image.max() + im_min = prev_min if prev_min is not None else image.min() + image = (image - im_min) / (im_max - im_min) image = image * (new_max - new_min) + new_min image = image.reshape(shape) return image From e01240e069e307737148d09aa25268d418375419 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 11:09:48 +0200 Subject: [PATCH 176/577] Added auto-set dims for cropping --- napari_cellseg3d/code_plugins/plugin_crop.py | 38 +++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index a27b4baa..e3ea55f5 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -3,6 +3,7 @@ import napari import numpy as np from magicgui import magicgui +from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -43,6 +44,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") + self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -112,6 +114,8 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() self._toggle_second_image_io_visibility() + self._check_image_list() + self.auto_set_dims() def _toggle_second_image_io_visibility(self): crop_2nd = self.crop_second_image_choice.isChecked() @@ -132,6 +136,16 @@ def _check_image_list(self): except IndexError: return + def auto_set_dims(self): + logger.debug(self.image_layer_loader.layer_name()) + data = self.image_layer_loader.layer_data() + if data is not None: + logger.debug("auto_set_dims : {}".format(data.shape)) + if len(data.shape) == 3: + for i, box in enumerate(self.crop_size_widgets): + logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + box.setValue(floor(data.shape[i] / 2)) + def _build(self): """Build buttons in a layout and add them to the napari Viewer""" @@ -266,9 +280,9 @@ def _start(self): except ValueError as e: logger.warning(e) logger.warning( - "Could not remove cropping layer programmatically!" + "Could not remove the previous cropping layer programmatically." ) - logger.warning("Maybe layer has been removed by user?") + # logger.warning("Maybe layer has been removed by user?") self.results_path = Path(self.results_filewidget.text_field.text()) @@ -346,7 +360,7 @@ def add_isotropic_layer( layer.data, name=f"Scaled_{layer.name}", colormap=colormap, - contrast_limits=contrast_lim, + # contrast_limits=contrast_lim, opacity=opacity, scale=self.aniso_factors, visible=visible, @@ -481,8 +495,8 @@ def set_slice( """ "Update cropped volume position""" # self._check_for_empty_layer(highres_crop_layer, highres_crop_layer.data) - logger.debug(f"axis : {axis}") - logger.debug(f"value : {value}") + # logger.debug(f"axis : {axis}") + # logger.debug(f"value : {value}") idx = int(value) scale = np.asarray(highres_crop_layer.scale) @@ -496,6 +510,20 @@ def set_slice( cropy = self._crop_size_y cropz = self._crop_size_z + if i + cropx > im1_stack.shape[0]: + cropx = im1_stack.shape[0] - i + if j + cropy > im1_stack.shape[1]: + cropy = im1_stack.shape[1] - j + if k + cropz > im1_stack.shape[2]: + cropz = im1_stack.shape[2] - k + + logger.debug(f"cropx : {cropx}") + logger.debug(f"cropy : {cropy}") + logger.debug(f"cropz : {cropz}") + logger.debug(f"i : {i}") + logger.debug(f"j : {j}") + logger.debug(f"k : {k}") + highres_crop_layer.data = im1_stack[ i : i + cropx, j : j + cropy, k : k + cropz ] From 097ac6140aa4cca6bfaae658e27e6d63a202f59f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 12:19:37 +0200 Subject: [PATCH 177/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0abcf387..60c25ccc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,24 +1,24 @@ -from pathlib import Path - import numpy as np -from tifffile import imread +from numpy.random import PCG64, Generator from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, ) +rand_gen = Generator(PCG64(12345)) + def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - im_path = str(Path(__file__).resolve().parent / "res/test.tif") - image = imread(im_path) - view.add_image(image) - view.add_labels(image.astype(np.uint8)) + image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image_layer = view.add_image(image, name="image") + label_layer = view.add_labels(image.astype(np.uint8), name="labels") view.window.add_dock_widget(widget) + view.dims.ndisplay = 3 for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( @@ -29,4 +29,6 @@ def test_utils_plugin(make_napari_viewer): menu = widget.utils_widgets[i].instance_widgets.method_choice menu.setCurrentIndex(menu.currentIndex() + 1) + assert len(image_layer.data.shape) == 3 + assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() From 14408727d934a31dc5dd0d64da89b7572c231ae5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:50:18 +0200 Subject: [PATCH 178/577] More WNet - Added experimental .pt loading for jit models - More CRF tests - Optimized WNet by loading inference only --- napari_cellseg3d/_tests/test_models.py | 61 ++++++++++++------ napari_cellseg3d/code_models/crf.py | 8 ++- .../code_models/model_framework.py | 2 +- .../code_models/models/model_WNet.py | 18 +++--- .../code_models/models/wnet/model.py | 19 ++++-- napari_cellseg3d/code_models/workers.py | 62 ++++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 19 +++--- .../dev_scripts/correct_labels.py | 12 ++-- pyproject.toml | 1 + 9 files changed, 124 insertions(+), 78 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 4852f651..c67b3cab 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -2,9 +2,14 @@ import torch from numpy.random import PCG64, Generator -from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf +from napari_cellseg3d.code_models.crf import ( + CRFWorker, + correct_shape_for_crf, + crf_batch, + crf_with_config, +) from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss -from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d.config import MODEL_LIST, CRFConfig rand_gen = Generator(PCG64(12345)) @@ -47,7 +52,38 @@ def test_soft_ncuts_loss(): assert 0 <= res <= 1 -def test_crf(qtbot): +def test_crf_batch(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_batch( + np.array([mock_image, mock_image, mock_image]), + np.array([mock_label, mock_label, mock_label]), + sa=config.sa, + sb=config.sb, + sg=config.sg, + w1=config.w1, + w2=config.w2, + ) + + assert isinstance(result, np.ndarray) + assert result.shape == (3, 2, dims, dims, dims) + + +def test_crf_config(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_with_config(mock_image, mock_label, config) + assert isinstance(result, np.ndarray) + assert result.shape == mock_label.shape + + +def test_crf_worker(qtbot): dims = 8 mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) @@ -60,20 +96,5 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=True - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() - - mock_image = mock_image[0] - mock_label = mock_label[0] - - crf = CRFWorker(mock_image, mock_label) - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=False - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() + result = next(crf._run_crf_job()) + on_yield(result) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 8c311059..b362246a 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,6 +54,8 @@ def correct_shape_for_crf(image, desired_dims=4): + logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") + logger.debug(f"Image shape: {image.shape}") if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( @@ -62,6 +64,7 @@ def correct_shape_for_crf(image, desired_dims=4): image = np.squeeze(image, axis=0) elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) + logger.debug(f"Corrected image shape: {image.shape}") return image @@ -210,9 +213,12 @@ def _run_crf_job(self): if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - im = correct_shape_for_crf(self.labels[i]) + im = correct_shape_for_crf(self.images[i]) prob = correct_shape_for_crf(self.labels[i]) + logger.debug(f"image shape : {im.shape}") + logger.debug(f"labels shape : {prob.shape}") + yield crf( im, prob, diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 60644916..0296e0cf 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -281,7 +281,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth)", + filetype="Weights file (*.pth, *.pt)", ) if file[0] == self._default_weights_folder: return diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index cb5ef6d8..62142e73 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,8 +1,8 @@ # local -from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder -class WNet_(WNet): +class WNet_(WNet_encoder): use_default_training = False weights_file = "wnet.pth" @@ -24,13 +24,13 @@ def __init__( # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") - def forward(self, x): - """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference. - """ - return self.forward_encoder(x) - # enc = self.forward_encoder(x) - # return self.forward_decoder(enc) + # def forward(self, x): + # """Forward ENCODER pass of the W-Net model. + # Done this way to allow inference on the encoder only when called by sliding_window_inference. + # """ + # return self.forward_encoder(x) + # # enc = self.forward_encoder(x) + # # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 585ea0dd..a23084d0 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -16,6 +16,19 @@ ] +class WNet_encoder(nn.Module): + """WNet with encoder only.""" + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super().__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + + def forward(self, x): + """Forward pass of the W-Net model.""" + return self.forward_encoder(x) + + class WNet(nn.Module): """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. The model performs unsupervised segmentation of 3D images. @@ -36,13 +49,11 @@ def forward(self, x): def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" - enc = self.encoder(x) - return enc + return self.encoder(x) def forward_decoder(self, enc): """Forward pass of the decoder part of the W-Net model.""" - dec = self.decoder(enc) - return dec + return self.decoder(enc) class UNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index dd9e38e3..8b3da42d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -820,41 +820,43 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config - - # try: - self.log("Instantiating model...") - model = model_class( # FIXME test if works - input_img_size=[dims, dims, dims], - device=self.config.device, - num_classes=self.config.model_info.num_classes, - ) - # try: - model = model.to(self.config.device) - # except Exception as e: - # self.raise_error(e, "Issue loading model to device") - # logger.debug(f"model : {model}") - if model is None: - raise ValueError("Model is None") + if Path(weights_config.path).suffix == ".pt": + model = torch.jit.load(weights_config.path) # try: - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path else: - self.downloader.download_weights( - model_name, - model_class.weights_file, - ) - weights = str( - PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[dims, dims, dims], + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, + ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + ) - model.load_state_dict( # note that this is redefined in WNet_ - torch.load( - weights, - map_location=self.config.device, + model.load_state_dict( # note that this is redefined in WNet_ + torch.load( + weights, + map_location=self.config.device, + ) ) - ) - self.log("Done") + self.log("Done") # except Exception as e: # self.raise_error(e, "Issue loading weights") # except Exception as e: diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index e3ea55f5..74691e1f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,9 +1,9 @@ +from math import floor from pathlib import Path import napari import numpy as np from magicgui import magicgui -from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -44,7 +44,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") - self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self.auto_set_dims + ) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -140,10 +142,12 @@ def auto_set_dims(self): logger.debug(self.image_layer_loader.layer_name()) data = self.image_layer_loader.layer_data() if data is not None: - logger.debug("auto_set_dims : {}".format(data.shape)) + logger.debug(f"auto_set_dims : {data.shape}") if len(data.shape) == 3: for i, box in enumerate(self.crop_size_widgets): - logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + logger.debug( + f"setting dim {i} to {floor(data.shape[i]/2)}" + ) box.setValue(floor(data.shape[i] / 2)) def _build(self): @@ -433,9 +437,8 @@ def _add_crop_sliders( box.value() for box in self.crop_size_widgets ] ############# - dims = [self._x, self._y, self._z] - [logger.debug(f"{dim}") for dim in dims] - logger.debug("SET DIMS ATTEMPT") + # [logger.debug(f"{dim}") for dim in dims] + # logger.debug("SET DIMS ATTEMPT") # if not self.create_new_layer.isChecked(): # self._x = x # self._y = y @@ -451,6 +454,8 @@ def _add_crop_sliders( # define crop sizes and boundaries for the image crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] + # [logger.debug(f"{crop}") for crop in crop_sizes] + # logger.debug("SET CROP ATTEMPT") for i in range(len(crop_sizes)): if crop_sizes[i] > im1_stack.shape[i]: diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4a7363b2..f413812d 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -363,9 +363,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") -# -# image_path = str(im_path / "volumes/images.tif") -# gt_labels_path = str(im_path / "labels/testing_im.tif") -# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") + + image_path = str(im_path / "volumes/c1images.tif") + gt_labels_path = str(im_path / "labels/c1labels.tif") + relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) diff --git a/pyproject.toml b/pyproject.toml index 082176b6..4c8ceed0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ docs = [ test = [ "pytest", "pytest_qt", + "pytest-cov", "coverage", "tox", "twine", From 71a08a6483cbae627c48bcc0e893adf43e789179 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:51:34 +0200 Subject: [PATCH 179/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index fb6fb71c..3684296e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -35,7 +35,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): * An option to load custom weights for the selected model (e.g. from training module) - * Post-processing : + * Additional options : * A box to select if data is anisotropic, if checked, asks for resolution in micron for each axis * A box to choose whether to threshold, if checked asks for a threshold between 0 and 1 @@ -406,7 +406,7 @@ def _build(self): ################################# # post proc group post_proc_group, post_proc_layout = ui.make_group( - "Post-processing", parent=self + "Additional options", parent=self ) self.thresholding_slider.container.setVisible(False) From 5ab0177e2493e0b6ccc3be6811fbb11884e48980 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:51:34 +0200 Subject: [PATCH 180/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index c49b7d23..74dc62e5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -43,7 +43,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): * An option to load custom weights for the selected model (e.g. from training module) - * Post-processing : + * Additional options : * A box to select if data is anisotropic, if checked, asks for resolution in micron for each axis * A box to choose whether to threshold, if checked asks for a threshold between 0 and 1 @@ -444,7 +444,7 @@ def _build(self): ################################# # post proc group post_proc_group, post_proc_layout = ui.make_group( - "Post-processing", parent=self + "Additional options", parent=self ) self.thresholding_slider.container.setVisible(False) From 77b91537674ed454b201dbbce520056c978f7b7f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:12:07 +0200 Subject: [PATCH 181/577] Update crf test/deps for testing --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index fa6905d5..0911e358 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index c67b3cab..ec7462db 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -68,7 +68,6 @@ def test_crf_batch(): w2=config.w2, ) - assert isinstance(result, np.ndarray) assert result.shape == (3, 2, dims, dims, dims) @@ -79,7 +78,6 @@ def test_crf_config(): config = CRFConfig() result = crf_with_config(mock_image, mock_label, config) - assert isinstance(result, np.ndarray) assert result.shape == mock_label.shape @@ -91,7 +89,6 @@ def test_crf_worker(qtbot): crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): - assert isinstance(result, np.ndarray) assert len(result.shape) == 4 assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] From 0bd27406ea6629264eb4f8844fce7448efb497f8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:20:30 +0200 Subject: [PATCH 182/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0911e358..d09be5f0 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -87,6 +86,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From d382faf27f967e996f8edf8238ee8c0e11c3ee70 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:34:33 +0200 Subject: [PATCH 183/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d09be5f0..d36e03a3 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -86,7 +87,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From cd401734a5d438b46a9f10e249ed5cc08b61b572 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:28 +0200 Subject: [PATCH 184/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index abdbfd63..674226eb 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python From 97f7bb798a8f0a5ebea7e511475177375aacaed5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:45 +0200 Subject: [PATCH 185/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d36e03a3..60bc5505 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox From 8af214288fa9e1c611edec08cf2a136e3829f512 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:50:44 +0200 Subject: [PATCH 186/577] Trying to fix tox install of pydensecrf --- .github/workflows/test_and_deploy.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 60bc5505..e9a66ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions -# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/tox.ini b/tox.ini index 674226eb..b4855dce 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python From 00e68442a754f94e74f683be98a52bd95e9bd4c3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:23:51 +0200 Subject: [PATCH 187/577] Added experimental ONNX support for inference --- .../code_models/model_framework.py | 15 ++++---- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 34 ++++++++++++++++++- .../code_plugins/plugin_model_inference.py | 14 +++++++- pyproject.toml | 8 +++++ 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 0296e0cf..f379ccb8 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -273,6 +273,14 @@ def get_available_models(): # self.lbl_model_path.setText(self.model_path) # # self.update_default() + def _update_weights_path(self, file): + if file[0] == self._default_weights_folder: + return + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) + def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`""" @@ -283,12 +291,7 @@ def _load_weights_path(self): [self._default_weights_folder], filetype="Weights file (*.pth, *.pt)", ) - if file[0] == self._default_weights_folder: - return - if file is not None and file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + self._update_weights_path(file) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index a23084d0..f98829bb 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -26,7 +26,7 @@ def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): def forward(self, x): """Forward pass of the W-Net model.""" - return self.forward_encoder(x) + return self.encoder(x) class WNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8b3da42d..be88c835 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,6 +199,34 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files +class ONNXModelWrapper(torch.nn.Module): + """Class to replace torch model if ONNX is used""" + def __init__(self, file_location): + super().__init__() + try: + import onnx + import onnxruntime as ort + except ImportError as e: + logger.error("ONNX is not installed but ONNX model was loaded") + logger.error(e) + msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" + logger.error(msg) + raise ImportError(msg) + + self.ort_session = ort.InferenceSession( + file_location, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + + def forward(self, modeL_input): + outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + return torch.tensor(outputs[0]) + + def eval(self): + return True + + def to(self, device): + return True @dataclass class InferenceResult: @@ -821,9 +849,13 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config if Path(weights_config.path).suffix == ".pt": + self.log("Instantiating PyTorch jit model...") model = torch.jit.load(weights_config.path) # try: - else: + elif Path(weights_config.path).suffix == ".onnx": + self.log("Instantiating ONNX model...") + model = ONNXModelWrapper(weights_config.path) + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 74dc62e5..599ec5b3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING - +from pathlib import Path import numpy as np import pandas as pd @@ -348,6 +348,18 @@ def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box""" ui.toggle_visibility(self.window_infer_box, self.window_infer_params) + def _load_weights_path(self): + """Show file dialog to set :py:attr:`model_path`""" + + # logger.debug(self._default_weights_folder) + + file = ui.open_file_dialog( + self, + [self._default_weights_folder], + filetype="Weights file (*.pth, *.pt, *.onnx)", + ) + self._update_weights_path(file) + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer""" diff --git a/pyproject.toml b/pyproject.toml index 4c8ceed0..3b06f947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,3 +119,11 @@ test = [ "twine", "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] +onnx-cpu = [ + "onnx", + "onnxruntime" +] +onnx-gpu = [ + "onnx", + "onnxruntime-gpu" +] From ddc3d1fe4da2004f564313e228c06d1300a7e1b9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:47:48 +0200 Subject: [PATCH 188/577] Updated WNet for ONNX conversion --- .../code_models/models/wnet/model.py | 59 +++++++++++-------- napari_cellseg3d/code_models/workers.py | 9 ++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index f98829bb..23584b30 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -59,18 +59,33 @@ def forward_decoder(self, enc): class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, encoder=True): + def __init__( + self, device, in_channels, out_channels, encoder=True, dropout=0.65 + ): super(UNet, self).__init__() self.device = device - self.in_b = InBlock(device, in_channels, 64) - self.conv1 = Block(device, 64, 128) - self.conv2 = Block(device, 128, 256) - self.conv3 = Block(device, 256, 512) - self.bot = Block(device, 512, 1024) - self.deconv1 = Block(device, 1024, 512) - self.deconv2 = Block(device, 512, 256) - self.deconv3 = Block(device, 256, 128) - self.out_b = OutBlock(device, 128, out_channels) + self.max_pool = nn.MaxPool3d(2) + self.in_b = InBlock(device, in_channels, 64, dropout=dropout) + self.conv1 = Block(device, 64, 128, dropout=dropout) + self.conv2 = Block(device, 128, 256, dropout=dropout) + self.conv3 = Block(device, 256, 512, dropout=dropout) + self.bot = Block(device, 512, 1024, dropout=dropout) + self.deconv1 = Block(device, 1024, 512, dropout=dropout) + self.conv_trans1 = nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + ) + self.deconv2 = Block(device, 512, 256, dropout=dropout) + self.conv_trans2 = nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + ) + self.deconv3 = Block(device, 256, 128, dropout=dropout) + self.conv_trans3 = nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + ) + self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) + self.conv_trans_out = nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + ) self.sm = nn.Softmax(dim=1).to(device) self.encoder = encoder @@ -78,17 +93,15 @@ def __init__(self, device, in_channels, out_channels, encoder=True): def forward(self, x): """Forward pass of the U-Net model.""" in_b = self.in_b(x.to(self.device)) - c1 = self.conv1(nn.MaxPool3d(2)(in_b)) - c2 = self.conv2(nn.MaxPool3d(2)(c1)) - c3 = self.conv3(nn.MaxPool3d(2)(c2)) - x = self.bot(nn.MaxPool3d(2)(c3)) + c1 = self.conv1(self.max_pool(in_b)) + c2 = self.conv2(self.max_pool(c1)) + c3 = self.conv3(self.max_pool(c2)) + x = self.bot(self.max_pool(c3)) x = self.deconv1( torch.cat( [ c3, - nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device - )(x), + self.conv_trans1(x), ], dim=1, ) @@ -97,9 +110,7 @@ def forward(self, x): torch.cat( [ c2, - nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device - )(x), + self.conv_trans2(x), ], dim=1, ) @@ -108,9 +119,7 @@ def forward(self, x): torch.cat( [ c1, - nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device - )(x), + self.conv_trans3(x), ], dim=1, ) @@ -119,9 +128,7 @@ def forward(self, x): torch.cat( [ in_b, - nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device - )(x), + self.conv_trans_out(x), ], dim=1, ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index be88c835..bf6b8542 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -200,7 +200,7 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files class ONNXModelWrapper(torch.nn.Module): - """Class to replace torch model if ONNX is used""" + """Class to replace torch model by ONNX Runtime session""" def __init__(self, file_location): super().__init__() try: @@ -219,14 +219,17 @@ def __init__(self, file_location): ) def forward(self, modeL_input): + """Wraps ONNX output in a torch tensor""" outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) return torch.tensor(outputs[0]) def eval(self): - return True + """Dummy function to replace model.eval()""" + pass def to(self, device): - return True + """Dummy function to replace model.to(device)""" + pass @dataclass class InferenceResult: From 2536c82c33ace9e569ef80d2c10c1f50e9143120 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:56:45 +0200 Subject: [PATCH 189/577] Added dropout param --- .../code_models/models/wnet/model.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 23584b30..3416acb1 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -141,17 +141,17 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(InBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -163,19 +163,19 @@ def forward(self, x): class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(Block, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), nn.Conv3d(in_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.Conv3d(out_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -187,21 +187,21 @@ def forward(self, x): class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(OutBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, out_channels, 1, device=device), ).to(device) def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) + return self.module(x.to(self.device)) \ No newline at end of file From 16e28cb2ea1cf16e4a5ff6155d5a2702ea6aae5c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 31 May 2023 16:13:42 +0200 Subject: [PATCH 190/577] Minor fixes in training --- napari_cellseg3d/code_models/workers.py | 8 ++++---- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 +++- napari_cellseg3d/interface.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index bf6b8542..c67ea523 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1531,13 +1531,13 @@ def get_loader_func(num_samples): or epoch + 1 == self.config.max_epochs ): model.eval() + self.log("Performing validation...") with torch.no_grad(): for val_data in val_loader: val_inputs, val_labels = ( val_data["image"].to(device), val_data["label"].to(device), ) - self.log("Performing validation...") try: with torch.no_grad(): val_outputs = sliding_window_inference( @@ -1606,8 +1606,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1620,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 2a131a5f..3e666dcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -169,6 +169,8 @@ def __init__( self.validation_values = [] # self.model_choice.setCurrentIndex(0) + wnet_index = self.model_choice.findText("WNet") + self.model_choice.removeItem(wnet_index) ################################ # interface @@ -813,7 +815,7 @@ def start(self): ) self._set_worker_config() - self.worker = TrainingWorker(config=self.worker_config) + self.worker = TrainingWorker(worker_config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 061f4d1d..3effeb86 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1236,7 +1236,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") return QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + "/.." + widget, "Open directory", default_path # + "/.." ) From 722da1bbeba9e8a4da431e2d01d5bc45e7354966 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:31:23 +0200 Subject: [PATCH 191/577] Fix weights file extension in inference + coverage - Remove unused scripts - More tests - Fixed weights type in inference --- .coveragerc | 7 + .gitignore | 1 + napari_cellseg3d/_tests/test_dock_widget.py | 1 + .../_tests/test_labels_correction.py | 8 +- .../_tests/test_plugin_inference.py | 2 + napari_cellseg3d/_tests/test_plugins.py | 21 ++ napari_cellseg3d/_tests/test_utils.py | 29 ++- .../code_models/model_framework.py | 28 +-- .../code_models/models/wnet/crf.py | 112 --------- napari_cellseg3d/code_plugins/plugin_crf.py | 6 +- .../code_plugins/plugin_metrics.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 +- napari_cellseg3d/dev_scripts/convert.py | 26 -- napari_cellseg3d/dev_scripts/drafts.py | 15 -- .../dev_scripts/evaluate_labels.py | 2 +- .../extract_extra_channels_labels.py | 144 ----------- napari_cellseg3d/dev_scripts/view_brain.py | 8 - napari_cellseg3d/dev_scripts/view_sample.py | 29 --- .../dev_scripts/weight_conversion.py | 234 ------------------ napari_cellseg3d/interface.py | 6 +- napari_cellseg3d/utils.py | 2 +- tox.ini | 3 +- 22 files changed, 74 insertions(+), 620 deletions(-) create mode 100644 .coveragerc create mode 100644 napari_cellseg3d/_tests/test_plugins.py delete mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py delete mode 100644 napari_cellseg3d/dev_scripts/convert.py delete mode 100644 napari_cellseg3d/dev_scripts/drafts.py delete mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py delete mode 100644 napari_cellseg3d/dev_scripts/view_brain.py delete mode 100644 napari_cellseg3d/dev_scripts/view_sample.py delete mode 100644 napari_cellseg3d/dev_scripts/weight_conversion.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..038f3d5a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +exclude_lines = + if __name__ == .__main__.: + +[run] +omit = + napari_cellseg3d/setup.py diff --git a/.gitignore b/.gitignore index df67a187..7460d861 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ notebooks/instance_test.ipynb !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif +cov.syspath.txt diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index 7737e540..8063c92b 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -11,6 +11,7 @@ def test_prepare(make_napari_viewer): viewer = make_napari_viewer() viewer.add_image(image) widget = Datamanager(viewer) + viewer.window.add_dock_widget(widget) widget.prepare(path_image, ".tif", "", False) diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..b4f13238 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -37,16 +37,16 @@ def test_correct_labels(): ) -def test_relabel(make_napari_viewer): - viewer = make_napari_viewer() +def test_relabel(): cl.relabel( str(image_path), str(labels_path), go_fast=True, - viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance(labels, labels, print_details=True) + el.evaluate_model_performance( + labels, labels, print_details=True, visualize=False + ) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index ca8e84d4..1ae83102 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -57,3 +57,5 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) + + widget.on_yield(res) diff --git a/napari_cellseg3d/_tests/test_plugins.py b/napari_cellseg3d/_tests/test_plugins.py new file mode 100644 index 00000000..c58d26af --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugins.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from napari_cellseg3d import plugins +from napari_cellseg3d.code_plugins import plugin_metrics as m + + +def test_all_plugins_import(make_napari_viewer): + plugins.napari_experimental_provide_dock_widget() + + +def test_plugin_metrics(make_napari_viewer): + viewer = make_napari_viewer() + w = m.MetricsUtils(viewer=viewer, parent=None) + viewer.window.add_dock_widget(w) + + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + labels_path = im_path + + w.image_filewidget.text_field = im_path + w.labels_filewidget.text_field = labels_path + w.compute_dice() diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 0b28183d..dc680b35 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,14 +1,15 @@ -import os from functools import partial +from pathlib import Path import numpy as np import torch from napari_cellseg3d import utils +from napari_cellseg3d.dev_scripts import thread_test def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] + test_list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -30,11 +31,11 @@ def test_fill_list_in_between(): "", ] - assert utils.fill_list_in_between(list, 2, "") == res + assert utils.fill_list_in_between(test_list, 2, "") == res fill = partial(utils.fill_list_in_between, n=2, fill_value="") - assert fill(list) == res + assert fill(test_list) == res def test_align_array_sizes(): @@ -109,11 +110,19 @@ def test_normalize_x(): def test_parse_default_path(): - user_path = os.path.expanduser("~") - assert utils.parse_default_path([None]) == user_path + user_path = Path().home() + assert utils.parse_default_path([None]) == str(user_path) - path = ["C:/test/test", None, None] - assert utils.parse_default_path(path) == "C:/test/test" + test_path = "C:/test/test" + path = [test_path, None, None] + assert utils.parse_default_path(path) == test_path - path = ["C:/test/test", None, None, "D:/very/long/path/what/a/bore", ""] - assert utils.parse_default_path(path) == "D:/very/long/path/what/a/bore" + long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime" + path = [test_path, None, None, long_path, ""] + assert utils.parse_default_path(path) == long_path + + +def test_thread_test(make_napari_viewer): + viewer = make_napari_viewer() + w = thread_test.create_connected_widget(viewer) + viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index f379ccb8..ddd9cd28 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -289,7 +289,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt)", + file_extension="Weights file (*.pth)", ) self._update_weights_path(file) @@ -311,31 +311,5 @@ def empty_cuda_cache(self): torch.cuda.empty_cache() logger.info("Attempt complete : Cache emptied") - # def update_default(self): # TODO add custom models - # """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" - # - # if len(self.images_filepaths) != 0: - # from_images = str(Path(self.images_filepaths[0]).parent) - # else: - # from_images = None - # - # if len(self.labels_filepaths) != 0: - # from_labels = str(Path(self.labels_filepaths[0]).parent) - # else: - # from_labels = None - # - # possible_paths = [ - # path - # for path in [ - # from_images, - # from_labels, - # # self.model_path, - # self.results_path, - # ] - # if path is not None - # ] - # self._default_folders = possible_paths - # update if model_path is used again - def _build(self): raise NotImplementedError("Should be defined in children classes") diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py deleted file mode 100644 index 004db3a1..00000000 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Implements the CRF post-processing step for the W-Net. -Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - -Also uses research from: -Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials -Philipp Krähenbühl and Vladlen Koltun -NIPS 2011 - -Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. -""" - -import numpy as np -import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - create_pairwise_bilateral, - create_pairwise_gaussian, - unary_from_softmax, -) - -__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" -__credits__ = [ - "Yves Paychère", - "Colin Hofmann", - "Cyril Achard", - "Philipp Krähenbühl", - "Vladlen Koltun", - "Liang-Chieh Chen", - "George Papandreou", - "Iasonas Kokkinos", - "Kevin Murphy", - "Alan L. Yuille", - "Xide Xia", - "Brian Kulis", - "Lucas Beyer", -] - - -def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): - """CRF post-processing step for the W-Net, applied to a batch of images. - - Args: - images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. - probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. - """ - - return np.stack( - [ - crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) - for i in range(images.shape[0]) - ], - axis=0, - ) - - -def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): - """Implements the CRF post-processing step for the W-Net. - Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - Implemented using the pydensecrf library. - - Args: - image (np.ndarray): Array of shape (C, H, W, D) containing the input image. - prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. - """ - d = dcrf.DenseCRF( - image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] - ) - # print(f"Image shape : {image.shape}") - # print(f"Prob shape : {prob.shape}") - # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels - - # Get unary potentials from softmax probabilities - U = unary_from_softmax(prob) - d.setUnaryEnergy(U) - - # Generate pairwise potentials - featsGaussian = create_pairwise_gaussian( - sdims=(sg, sg, sg), shape=image.shape[1:] - ) # image.shape) - featsBilateral = create_pairwise_bilateral( - sdims=(sa, sa, sa), - schan=tuple([sb for i in range(image.shape[0])]), - img=image, - chdim=-1, - ) - - # Add pairwise potentials to the CRF - compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( - [1 for i in range(prob.shape[0])] - # , dtype=np.float32 - ) - d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) - d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) - - # Run inference - Q = d.inference(n_iter) - - return np.array(Q).reshape( - (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) - ) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index d8407a0f..76194e87 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial from pathlib import Path @@ -277,7 +278,10 @@ def _on_start(self): def _on_finish(self): self.worker = None - self.start_button.setText("Start") + with contextlib.suppress(RuntimeError): + self.start_button.setText("Start") + + # should only happen when testing def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 2a6e713c..1dc5e7de 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -23,7 +23,7 @@ class MetricsUtils(BasePluginFolder): """Plugin to evaluate metrics between two sets of labels, ground truth and prediction""" - def __init__(self, viewer: "napari.viewer.Viewer", parent): + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): """Creates a MetricsUtils widget for computing and plotting dice metrics between labels. Args: viewer: viewer to display the widget in diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 599ec5b3..256cffa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING -from pathlib import Path + import numpy as np import pandas as pd @@ -171,7 +171,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default + self.window_size_choice.setCurrentIndex( + self._default_window_size + ) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -356,7 +358,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt, *.onnx)", + file_extension="Weights file (*.pth *.pt *.onnx)", ) self._update_weights_path(file) diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py deleted file mode 100644 index 641de627..00000000 --- a/napari_cellseg3d/dev_scripts/convert.py +++ /dev/null @@ -1,26 +0,0 @@ -import glob -import os - -import numpy as np -from tifffile import imread, imwrite - -# input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" -# output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" - -input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/labels" -output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/lab_sem" - -filenames = [] -paths = [] -filetype = ".tif" -for filename in glob.glob(os.path.join(input_seg_path, "*" + filetype)): - paths.append(filename) - filenames.append(os.path.basename(filename)) - # print(os.path.basename(filename)) -for file in paths: - image = imread(file) - - image[image >= 1] = 1 - image = image.astype(np.uint16) - - imwrite(output_seg_path + "/" + os.path.basename(file), image) diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py deleted file mode 100644 index cdd02256..00000000 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ /dev/null @@ -1,15 +0,0 @@ -import napari -import numpy as np -from magicgui import magicgui -from napari.types import ImageData, LabelsData - - -@magicgui(call_button="Run Threshold") -def threshold(image: ImageData, threshold: int = 75) -> LabelsData: - """Threshold an image and return a mask.""" - return (image > threshold).astype(int) - - -viewer = napari.view_image(np.random.randint(0, 100, (64, 64))) -viewer.window.add_dock_widget(threshold) -threshold() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index ee9919b6..2830f4e7 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -127,7 +127,7 @@ def evaluate_model_performance( ) if visualize: - viewer = napari.Viewer() + viewer = napari.Viewer(ndisplay=3) viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py deleted file mode 100644 index 70ee10b6..00000000 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -from skimage.filters import threshold_otsu -from skimage.segmentation import expand_labels -from tqdm import tqdm - - -def extract_labels_from_channels( # TODO add separate channels results - nuclei_labels: np.array, - extra_channels: list, - radius: int = 4, - threshold_factor=2, - viewer=None, -): - """ - Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. - Args: - nuclei_labels (np.array): labels for the nuclei - extra_channels (list): channels arrays to extract labels from - radius: radius in which the approximation is made - - Returns: - A list of extracted labels for each extra channel - """ - labeled_channels = [] - contrasted_channels = [] - for channel in extra_channels: - channel = (channel - np.min(channel)) / ( - np.max(channel) - np.min(channel) - ) - threshold_brightness = threshold_otsu(channel) * threshold_factor - channel_contrasted = np.where( - channel > threshold_brightness, channel, 0 - ) - contrasted_channels.append(channel_contrasted) - if viewer is not None: - viewer.add_image( - channel_contrasted, - name="channel_contrasted", - colormap="viridis", - ) - for label_id in tqdm(np.unique(nuclei_labels)): - if label_id == 0: - continue - label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) - expanded = expand_labels(label_nucleus, distance=radius) - restricted = np.where(expanded != 0, nuclei_labels, 0) - overlap = np.where(restricted != label_id, restricted, 0) - - for i, channel in enumerate(contrasted_channels): - label_contrasted = np.where(expanded != 0, channel, 0) - if overlap.any() != 0: - max_labeled = 0 - for overlap_id in np.unique(overlap): - if overlap_id == 0: - continue - assigned_pixels = np.count_nonzero( - np.where(overlap == overlap_id, channel, 0) - ) - if assigned_pixels > max_labeled: - max_labeled = assigned_pixels - max_label_id = overlap_id - if label_id != max_label_id: - labeled_channels.append( - np.zeros_like(label_contrasted) - ) - else: - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels.append(labeled_channel) - if ( - np.count_nonzero(labeled_channel) > 0 - and viewer is not None - ): - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) - - cat_labels = np.zeros_like(nuclei_labels) - for labels in np.unique(labeled_channels): - if labels == 0: - continue - cat_labels += np.where(labels != 0, labels, 0) - return cat_labels - - -if __name__ == "__main__": - from pathlib import Path - - import napari - from tifffile import imread - - image_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - ) - # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - nuclei_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" - ) - extra_channels_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" - ) - extra_channels = [ - imread(str(path)) - for path in extra_channels_path.glob( - "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" - ) - ] - labels = imread(str(image_path)) - viewer = napari.Viewer() - - shift = 0 - viewer.add_image( - imread(str(nuclei_path))[ - shift : 32 + shift, shift : 32 + shift, shift : 32 + shift - ], - name="nuclei", - ) - viewer.add_labels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - [ - viewer.add_image( - channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - for channel in extra_channels - ] - - labeled_channels = extract_labels_from_channels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], - [ - c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - for c in extra_channels - ], - radius=4, - viewer=viewer, - ) - - viewer.add_labels(labeled_channels) - # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] - # expanded = expand_labels(labels, 4) - # viewer.add_labels(expanded) - napari.run() diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py deleted file mode 100644 index 145d4e45..00000000 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ /dev/null @@ -1,8 +0,0 @@ -import napari -from tifffile import imread - -y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") - -with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(y, contrast_limits=[0, 2000], multiscale=False) diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py deleted file mode 100644 index 8e87f85c..00000000 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ /dev/null @@ -1,29 +0,0 @@ -import napari -from tifffile import imread - -# Visual -x = imread( - "/Users/maximevidal/Documents/trailmap/data/no-edge-validation/visual-original/volumes/images.tif" -) -y_semantic = imread( - "/Users/maximevidal/Documents/trailmap/data/testing/seg-visual1-single/image.tif" -) -y_instance = imread( - "/Users/maximevidal/Documents/trailmap/data/instance-testing/test-visual-5.tiff" -) -y_true = imread( - "/Users/maximevidal/Documents/3drawdata/visual/labels/labels.tif" -) - -# SM -# x = imread("/Users/maximevidal/Documents/trailmap/data/no-edge-validation/validation-original/volumes/c5images.tif") -# y = imread("/Users/maximevidal/Documents/trailmap/data/instance-testing/test1.tiff") -# y_true = imread("/Users/maximevidal/Documents/3drawdata/somatomotor/labels/c5labels.tif") - -with napari.gui_qt(): - viewer = napari.view_image( - x, colormap="inferno", contrast_limits=[200, 1000] - ) - viewer.add_image(y_semantic, name="semantic_predictions", opacity=0.5) - viewer.add_labels(y_instance, name="instance_predictions", seed=0.6) - viewer.add_labels(y_true, name="truth", seed=0.6) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py deleted file mode 100644 index 6cdb9c43..00000000 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ /dev/null @@ -1,234 +0,0 @@ -import collections -import os - -import torch - -from napari_cellseg3d.code_models.models import get_net -from napari_cellseg3d.code_models.models.unet.model import UNet3D - -# not sure this actually works when put here - - -def weight_translate(k, w): - k = key_translate(k) - if k.endswith(".weight"): - if w.dim() == 2: - w = w.t() - elif w.dim() == 1: - pass - elif w.dim() == 4: - w = w.permute(3, 2, 0, 1) - else: - assert w.dim() == 5 - w = w.permute(4, 3, 0, 1, 2) - return w - - -def key_translate(k): - k = ( - k.replace( - "conv3d/kernel:0", - "encoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization/gamma:0", - "encoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization/beta:0", - "encoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_1/kernel:0", - "encoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_1/gamma:0", - "encoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_1/beta:0", - "encoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_2/kernel:0", - "encoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_2/gamma:0", - "encoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_2/beta:0", - "encoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_3/kernel:0", - "encoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_3/gamma:0", - "encoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_3/beta:0", - "encoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_4/kernel:0", - "encoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_4/gamma:0", - "encoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_4/beta:0", - "encoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_5/kernel:0", - "encoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_5/gamma:0", - "encoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_5/beta:0", - "encoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_6/kernel:0", - "encoders.3.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_6/gamma:0", - "encoders.3.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_6/beta:0", - "encoders.3.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_7/kernel:0", - "encoders.3.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_7/gamma:0", - "encoders.3.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_7/beta:0", - "encoders.3.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_8/kernel:0", - "decoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_8/gamma:0", - "decoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_8/beta:0", - "decoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_9/kernel:0", - "decoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_9/gamma:0", - "decoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_9/beta:0", - "decoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_10/kernel:0", - "decoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_10/gamma:0", - "decoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_10/beta:0", - "decoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_11/kernel:0", - "decoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_11/gamma:0", - "decoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_11/beta:0", - "decoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_12/kernel:0", - "decoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_12/gamma:0", - "decoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_12/beta:0", - "decoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_13/kernel:0", - "decoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_13/gamma:0", - "decoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_13/beta:0", - "decoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace("conv3d_14/kernel:0", "final_conv.weight") - .replace("conv3d_14/bias:0", "final_conv.bias") - ) - return k - - -model = get_net() -base_path = os.path.abspath(__file__ + "/..") -weights_path = base_path + "/data/model-weights/trailmap_model.hdf5" -model.load_weights(weights_path) - -for i, l in enumerate(model.layers): - print(i, l) - print( - "L{}: {}".format( - i, ", ".join(str(w.shape) for w in model.layers[i].weights) - ) - ) - -weights_pt = collections.OrderedDict( - [(w.name, torch.from_numpy(w.numpy())) for w in model.trainable_variables] -) -torch.save(weights_pt, base_path + "/data/model-weights/trailmaptorch.pt") -torch_weights = torch.load(base_path + "/data/model-weights/trailmaptorch.pt") -param_dict = { - key_translate(k): weight_translate(k, v) for k, v in torch_weights.items() -} - -trailmap_model = UNet3D(1, 1) -torchparam = trailmap_model.state_dict() -for k, v in torchparam.items(): - print("{:20s} {}".format(k, v.shape)) - -trailmap_model.load_state_dict(param_dict, strict=False) -torch.save( - trailmap_model.state_dict(), - base_path + "/data/model-weights/trailmaptorchpretrained.pt", -) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 3effeb86..06a2190a 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1208,7 +1208,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, possible_paths: list = (), - filetype: str = "Image file (*.tif *.tiff)", + file_extension: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1217,14 +1217,14 @@ def open_file_dialog( possible_paths (str): Paths that may have been chosen before, can be a string or an array of strings containing the paths load_as_folder (bool): Whether to open a folder or a single file. If True, will allow opening folder as a single file (2D stack interpreted as 3D) - filetype (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` + file_extension (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` """ default_path = utils.parse_default_path(possible_paths) return QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype + widget, "Choose file", default_path, file_extension ) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index b6857a19..2f6094b7 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -520,7 +520,7 @@ def parse_default_path(possible_paths): # ] print(default_paths) if len(default_paths) == 0: - return str(Path.home()) + return str(Path().home()) default_path = max(default_paths, key=len) return str(default_path) diff --git a/tox.ini b/tox.ini index b4855dce..4c1c3e51 100644 --- a/tox.ini +++ b/tox.ini @@ -39,5 +39,6 @@ deps = git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python - +extras = crf +usedevelop = true commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From e8801a83da2bbb4f6c7cd37f5b5e8580d8a02fa2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:41:07 +0200 Subject: [PATCH 192/577] Run all hooks --- .../_tests/test_plugin_inference.py | 5 ++++- .../code_models/models/model_TRAILMAP.py | 15 +++++--------- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 20 +++++++++++-------- napari_cellseg3d/code_plugins/plugin_base.py | 15 ++++++-------- .../code_plugins/plugin_helper.py | 4 +++- .../code_plugins/plugin_utilities.py | 5 ++++- napari_cellseg3d/dev_scripts/thread_test.py | 6 ++++-- pyproject.toml | 2 +- 9 files changed, 40 insertions(+), 34 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1ae83102..1e486c14 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,9 +34,12 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() + widget.model_choice.setCurrentIndex(-1) + assert widget.window_infer_box.isChecked() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + widget.model_choice.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 8a108e37..e6bbad55 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -39,13 +39,12 @@ def forward(self, x): up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 # print(up8.shape) - out = self.out(up8) + return self.out(up8) # print("out:") # print(out.shape) - return out def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -56,10 +55,9 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.ReLU(), nn.MaxPool3d(2), ) - return encode def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -69,10 +67,9 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.BatchNorm3d(out_ch), nn.ReLU(), ) - return encode def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -85,13 +82,11 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) ), ) - return decode def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) - return out class TRAILMAP_(TRAILMAP): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 3416acb1..2900b89c 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -204,4 +204,4 @@ def __init__(self, device, in_channels, out_channels, dropout=0.65): def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) \ No newline at end of file + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c67ea523..245e6f02 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,28 +199,31 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files + class ONNXModelWrapper(torch.nn.Module): """Class to replace torch model by ONNX Runtime session""" + def __init__(self, file_location): super().__init__() try: - import onnx import onnxruntime as ort except ImportError as e: logger.error("ONNX is not installed but ONNX model was loaded") logger.error(e) msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" logger.error(msg) - raise ImportError(msg) + raise ImportError(msg) from e self.ort_session = ort.InferenceSession( file_location, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) def forward(self, modeL_input): """Wraps ONNX output in a torch tensor""" - outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + outputs = self.ort_session.run( + None, {"input": modeL_input.cpu().numpy()} + ) return torch.tensor(outputs[0]) def eval(self): @@ -231,6 +234,7 @@ def to(self, device): """Dummy function to replace model.to(device)""" pass + @dataclass class InferenceResult: """Class to record results of a segmentation job""" @@ -858,7 +862,7 @@ def inference(self): elif Path(weights_config.path).suffix == ".onnx": self.log("Instantiating ONNX model...") model = ONNXModelWrapper(weights_config.path) - else: # assume is .pth + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], @@ -1606,8 +1610,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1624,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 26da7a42..cfa3f0d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -227,17 +227,16 @@ def _show_filetype_choice(self): def _show_file_dialog(self): """Open file dialog and process path depending on single file/folder loading behaviour""" if self.load_as_stack_choice.isChecked(): - folder = ui.open_folder_dialog( + choice = ui.open_folder_dialog( self, self._default_path, filetype=f"Image file (*{self.filetype_choice.currentText()})", ) - return folder else: f_name = ui.open_file_dialog(self, self._default_path) - f_name = str(f_name[0]) - self.filetype = str(Path(f_name).suffix) - return f_name + choice = str(f_name[0]) + self.filetype = str(Path(choice).suffix) + return choice def _show_dialog_images(self): """Show file dialog and set image path""" @@ -291,16 +290,14 @@ def _make_close_button(self): return btn def _make_prev_button(self): - btn = ui.Button( + return ui.Button( "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) ) - return btn def _make_next_button(self): - btn = ui.Button( + return ui.Button( "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) ) - return btn def remove_from_viewer(self): """Removes the widget from the napari window. diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index f8ac18ef..00104bb5 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,6 +1,8 @@ import pathlib +from typing import TYPE_CHECKING -import napari +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import QSize diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 868dd279..6e1a606a 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -1,4 +1,7 @@ -import napari +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import qInstallMessageHandler diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 20668125..a48f6db0 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -1,8 +1,8 @@ import time import napari -import numpy as np from napari.qt.threading import thread_worker +from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -13,6 +13,8 @@ QWidget, ) +rand_gen = Generator(PCG64(12345)) + @thread_worker def two_way_communication_with_args(start, end): @@ -129,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) + viewer = napari.view_image(rand_gen.random(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/pyproject.toml b/pyproject.toml index 3b06f947..332bf768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ select = [ ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) # and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004"] +ignore = ["E501", "E741", "G004", "A003"] exclude = [ ".bzr", ".direnv", From 5abc574d6181d105fd101ca4e5294ebf1fe6b9a1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 11:27:58 +0200 Subject: [PATCH 193/577] Fix inference testing --- .../_tests/test_plugin_inference.py | 13 +++++++----- .../code_models/models/model_test.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1e486c14..779f5094 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,12 +34,15 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentIndex(-1) + widget.model_choice.setCurrentText("WNet") + widget._restrict_window_size_for_model() assert widget.window_infer_box.isChecked() + assert widget.window_size_choice.currentText() == "64" - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(-1) + test_model_name = "test" + MODEL_LIST[test_model_name] = TestModel + widget.model_choice.addItem(test_model_name) + widget.model_choice.setCurrentText(test_model_name) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None @@ -59,6 +62,6 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) - assert res.result.shape == (6, 6, 6) + assert res.result.shape == (8, 8, 8) widget.on_yield(res) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1cb52f06..28f3a05b 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -20,13 +20,13 @@ def forward(self, x): # return val_inputs -# if __name__ == "__main__": -# -# model = TestModel() -# model.train() -# model.zero_grad() -# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR -# torch.save( -# model.state_dict(), -# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" -# ) +if __name__ == "__main__": + model = TestModel() + model.train() + model.zero_grad() + from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR + + torch.save( + model.state_dict(), + PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", + ) From 0f1cf620f122032fb474ebb00b8eb01447faac20 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 13:45:50 +0200 Subject: [PATCH 194/577] Changed anisotropy calculation --- napari_cellseg3d/_tests/test_interface.py | 8 +++++++- napari_cellseg3d/interface.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index be811721..08e0e675 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.interface import Log +from napari_cellseg3d.interface import AnisotropyWidgets, Log def test_log(qtbot): @@ -12,3 +12,9 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) + + +def test_zoom_factor(): + resolution = [10.0, 10.0, 5.0] + zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) + assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 06a2190a..d2ec5789 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -735,8 +735,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = min(aniso_res) - return [base / res for res in aniso_res] + base = max(aniso_res) + return [res / base for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" From 98c05fcc0b0682192cd68199bb9b43ee66225e07 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 11:09:12 +0200 Subject: [PATCH 195/577] Updated hooks --- .pre-commit-config.yaml | 14 +- napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- .../_tests/test_weight_download.py | 6 +- .../code_models/model_framework.py | 15 +- .../code_models/model_instance_seg.py | 8 +- napari_cellseg3d/code_models/model_workers.py | 95 ++++++------ .../code_models/models/unet/buildingblocks.py | 5 +- .../code_models/models/unet/model.py | 6 +- napari_cellseg3d/code_plugins/plugin_base.py | 3 +- .../code_plugins/plugin_convert.py | 138 +++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 5 +- .../code_plugins/plugin_helper.py | 6 +- .../code_plugins/plugin_metrics.py | 3 +- .../code_plugins/plugin_model_inference.py | 23 +-- .../code_plugins/plugin_model_training.py | 23 +-- .../code_plugins/plugin_review.py | 15 +- .../code_plugins/plugin_review_dock.py | 11 +- .../code_plugins/plugin_utilities.py | 18 +-- napari_cellseg3d/config.py | 7 +- .../dev_scripts/artefact_labeling.py | 3 +- napari_cellseg3d/dev_scripts/convert.py | 3 +- .../dev_scripts/correct_labels.py | 8 +- napari_cellseg3d/dev_scripts/drafts.py | 3 +- napari_cellseg3d/dev_scripts/thread_test.py | 16 +- napari_cellseg3d/interface.py | 59 ++++---- napari_cellseg3d/utils.py | 8 +- pyproject.toml | 43 +++++- 27 files changed, 286 insertions(+), 264 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6bdc58e..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,13 +2,17 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] + - id: check-yaml + - id: check-added-large-files + - id: check-toml +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index cbfd97b2..0abcf387 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index b8f0d748..d8886a56 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,5 +1,7 @@ -from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR -from napari_cellseg3d.code_models.model_workers import WeightsDownloader +from napari_cellseg3d.code_models.model_workers import ( + WEIGHTS_DIR, + WeightsDownloader, +) # DISABLED, causes GitHub actions to freeze diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index b3121cf4..2cc4265e 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -5,13 +5,11 @@ import torch # Qt -from qtpy.QtWidgets import QProgressBar -from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QProgressBar, QSizePolicy # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder warnings.formatwarning = utils.format_Warning @@ -286,11 +284,10 @@ def _load_weights_path(self): ) if file[0] == self._default_weights_folder: return - if file is not None: - if file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c72bafe9..60f8bbda 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,8 +4,7 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -14,9 +13,8 @@ from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -561,7 +559,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 636f7acd..30d37bbd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -2,44 +2,46 @@ from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List -from typing import Optional +from typing import List, Optional import numpy as np import torch # MONAI -from monai.data import CacheDataset -from monai.data import DataLoader -from monai.data import Dataset -from monai.data import decollate_batch -from monai.data import pad_list_data_collate -from monai.data import PatchDataset +from monai.data import ( + CacheDataset, + DataLoader, + Dataset, + PatchDataset, + decollate_batch, + pad_list_data_collate, +) from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric -from monai.transforms import AddChannel -from monai.transforms import AsDiscrete -from monai.transforms import Compose -from monai.transforms import EnsureChannelFirstd -from monai.transforms import EnsureType -from monai.transforms import EnsureTyped -from monai.transforms import LoadImaged -from monai.transforms import Orientationd -from monai.transforms import Rand3DElasticd -from monai.transforms import RandAffined -from monai.transforms import RandFlipd -from monai.transforms import RandRotate90d -from monai.transforms import RandShiftIntensityd -from monai.transforms import RandSpatialCropSamplesd -from monai.transforms import SpatialPad -from monai.transforms import SpatialPadd -from monai.transforms import ToTensor -from monai.transforms import Zoom +from monai.transforms import ( + AddChannel, + AsDiscrete, + Compose, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + LoadImaged, + Orientationd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandShiftIntensityd, + RandSpatialCropSamplesd, + SpatialPad, + SpatialPadd, + ToTensor, + Zoom, +) from monai.utils import set_determinism # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -47,11 +49,12 @@ from tqdm import tqdm # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER @@ -112,7 +115,7 @@ def show_progress(count, block_size, total_size): with open(json_path) as f: neturls = json.load(f) - if model_name in neturls.keys(): + if model_name in neturls: url = neturls[model_name] response = urllib.request.urlopen(url) @@ -282,10 +285,11 @@ def log_parameters(self): f"Thresholding is enabled at {config.post_process_config.thresholding.threshold_value}" ) - if config.sliding_window_config.is_enabled(): - status = "enabled" - else: - status = "disabled" + status = ( + "enabled" + if config.sliding_window_config.is_enabled() + else "disabled" + ) self.log(f"Window inference is {status}\n") if status == "enabled": @@ -454,10 +458,9 @@ def model_output(inputs): self.config.model_info.get_model().get_output(model, inputs) ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) window_size = self.config.sliding_window_config.window_size window_overlap = self.config.sliding_window_config.window_overlap @@ -1055,10 +1058,7 @@ def train(self): do_sampling = self.config.sampling if model_name == "SegResNet": - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check logger.info(f"Size of image : {size}") model = model_class.get_net( input_image_size=utils.get_padding_dim(size), @@ -1066,10 +1066,7 @@ def train(self): # dropout_prob=0.3, ) elif model_name == "SwinUNetR": - if do_sampling: - size = self.sample_size - else: - size = check + size = self.sample_size if do_sampling else check logger.info(f"Size of image : {size}") model = model_class.get_net( img_size=utils.get_padding_dim(size), diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 4cdc0a43..73913ab8 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -64,10 +64,7 @@ def create_conv( ) elif char == "g": is_before_conv = i < order.index("c") - if is_before_conv: - num_channels = in_channels - else: - num_channels = out_channels + num_channels = in_channels if is_before_conv else out_channels # use only one group if the given number of groups is greater than the number of channels if num_channels < num_groups: diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index 6cc76be6..9591d054 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -1,12 +1,10 @@ import torch.nn as nn from napari_cellseg3d.code_models.models.unet.buildingblocks import ( + DoubleConv, create_decoders, -) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): @@ -66,7 +64,7 @@ def __init__( f_maps, num_levels=num_levels ) - assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert isinstance(f_maps, (list, tuple)) assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" # create encoder path diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 5191e66f..0a613ee7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -6,8 +6,7 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QTabWidget -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QTabWidget, QWidget # local from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ed1a43df..6c8370c1 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,15 +4,16 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -166,18 +167,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -261,18 +263,19 @@ def _start(self): show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -342,18 +345,19 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -428,18 +432,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method(imread(file)) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -522,18 +527,19 @@ def _start(self): show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index cb149b52..9830d51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -233,10 +233,7 @@ def quicksave(self): def _check_ready(self): if self.image_layer_loader.layer_data() is not None: if self.crop_second_image: - if self.label_layer_loader.layer_data() is not None: - return True - else: - return False + return self.label_layer_loader.layer_data() is not None return True return False diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 083b269b..f8ac18ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -4,10 +4,8 @@ # Qt from qtpy.QtCore import QSize -from qtpy.QtGui import QIcon -from qtpy.QtGui import QPixmap -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtGui import QIcon, QPixmap +from qtpy.QtWidgets import QVBoxLayout, QWidget # local from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index b2356526..114025f6 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -5,8 +5,7 @@ FigureCanvasQTAgg as FigureCanvas, ) from matplotlib.figure import Figure -from monai.transforms import SpatialPad -from monai.transforms import ToTensor +from monai.transforms import SpatialPad, ToTensor from tifffile import imread from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 3684296e..22867343 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -6,14 +6,17 @@ import pandas as pd # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -276,9 +279,11 @@ def check_ready(self): if self.layer_choice.isChecked(): if self.image_layer_loader.layer_data() is not None: return True - elif self.folder_choice.isChecked(): - if self.image_filewidget.check_ready(): - return True + elif ( + self.folder_choice.isChecked() + and self.image_filewidget.check_ready() + ): + return True return False def _toggle_display_model_input_size(self): diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index de54b345..cf8e4b85 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -14,23 +14,26 @@ from matplotlib.figure import Figure # MONAI -from monai.losses import DiceCELoss -from monai.losses import DiceFocalLoss -from monai.losses import DiceLoss -from monai.losses import FocalLoss -from monai.losses import GeneralizedDiceLoss -from monai.losses import TverskyLoss +from monai.losses import ( + DiceCELoss, + DiceFocalLoss, + DiceLoss, + FocalLoss, + GeneralizedDiceLoss, + TverskyLoss, +) # Qt from qtpy.QtWidgets import QSizePolicy # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import TrainingReport -from napari_cellseg3d.code_models.model_workers import TrainingWorker +from napari_cellseg3d.code_models.model_workers import ( + TrainingReport, + TrainingWorker, +) NUMBER_TABS = 3 DEFAULT_PATCH_SIZE = 64 diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index a803dfd7..7ed6c549 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -11,14 +11,12 @@ from matplotlib.figure import Figure # Qt -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QLineEdit, QSizePolicy from tifffile import imwrite # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager @@ -182,11 +180,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - warnings.warn( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + warnings.warn( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 8a25d6a6..c09c376f 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,14 +1,12 @@ import warnings -from datetime import datetime -from datetime import timedelta +from datetime import datetime, timedelta from pathlib import Path import napari import pandas as pd # Qt -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QVBoxLayout, QWidget from napari_cellseg3d import interface as ui from napari_cellseg3d import utils @@ -216,10 +214,7 @@ def create_csv(self, label_dir, model_type, filename=None): ) else: # print(self.image_dims[0]) - if self.filename is not None: - filename = self.filename - else: - filename = "image" + filename = self.filename if self.filename is not None else "image" labels = [str(filename) for i in range(self.image_dims[0])] df = pd.DataFrame( diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index c962717e..5463a4ff 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,17 +2,17 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils -from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils +from napari_cellseg3d.code_plugins.plugin_convert import ( + AnisoUtils, + RemoveSmallUtils, + ThresholdUtils, + ToInstanceUtils, + ToSemanticUtils, +) from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -82,7 +82,7 @@ def _update_visibility(self): # print("vis. updated") # print(self.utils_widgets) self._hide_all() - for i, w in enumerate(self.utils_widgets): + for _i, w in enumerate(self.utils_widgets): if isinstance(w, widget_class): w.setVisible(True) w.adjustSize() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index ab3dba39..737b53aa 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -2,8 +2,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import List -from typing import Optional +from typing import List, Optional import napari import numpy as np @@ -87,9 +86,7 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :\n" + str(f"{name}\n" for name in MODEL_LIST)) return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 90048a60..3f95e1a8 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 479a07dd..641de627 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,8 +2,7 @@ import os import numpy as np -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index aacf08f8..168990e1 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels @@ -227,10 +226,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py index adfb7914..cdd02256 100644 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ b/napari_cellseg3d/dev_scripts/drafts.py @@ -1,8 +1,7 @@ import napari import numpy as np from magicgui import magicgui -from napari.types import ImageData -from napari.types import LabelsData +from napari.types import ImageData, LabelsData @magicgui(call_button="Run Threshold") diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 998645cb..20668125 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -3,13 +3,15 @@ import napari import numpy as np from napari.qt.threading import thread_worker -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QProgressBar -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import ( + QGridLayout, + QLabel, + QProgressBar, + QPushButton, + QTextEdit, + QVBoxLayout, + QWidget, +) @thread_worker diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index ff3af55c..276f9214 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,40 +1,37 @@ import threading import warnings from functools import partial -from typing import List -from typing import Optional +from typing import List, Optional import napari # Qt # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -188,7 +185,7 @@ def show_utils_menu(self, widget, event): menu.setStyleSheet(f"background-color: {napari_grey}; color: white;") actions = [] - for title in UTILITIES_WIDGETS.keys(): + for title in UTILITIES_WIDGETS: a = menu.addAction(f"Utilities : {title}") actions.append(a) @@ -773,7 +770,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: warnings.warn("Please select a valid layer !") - return + return None return self._viewer.layers[self.layer_name()].data @@ -1011,7 +1008,7 @@ def make_n_spinboxes( raise ValueError("Cannot make less than 2 spin boxes") boxes = [] - for i in range(n): + for _i in range(n): box = class_(min, max, default, step, parent, fixed) boxes.append(box) return boxes diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 1ddbe67d..a52c3de9 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -136,9 +136,8 @@ def align_array_sizes(array_shape, target_shape): for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: for j in range(len(array_shape)): - if array_shape[i] == target_shape[j]: - if j != i: - index_differences.append({"origin": i, "target": j}) + if array_shape[i] == target_shape[j] and j != i: + index_differences.append({"origin": i, "target": j}) # print(index_differences) if len(index_differences) == 0: @@ -353,9 +352,10 @@ def fill_list_in_between(lst, n, elem): new_list += temp_list else: new_list.append(lst[i]) - for j in range(n): + for _j in range(n): new_list.append(elem) return new_list + return None # def check_zarr(project_path, ext): diff --git a/pyproject.toml b/pyproject.toml index 803338e4..9c5adda7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,47 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). -ignore = ["E501", "E741"] +select = [ + "E", "F", "W", + "A", + "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004", "A003"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", +] [tool.black] line-length = 79 From f11155c6e2144cd4a3198cdce320f8b260e7a7a8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:11:22 +0100 Subject: [PATCH 196/577] Update setup.cfg --- setup.cfg | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/setup.cfg b/setup.cfg index 3a0bdaae..ede7724d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,30 @@ python_requires = >=3.8 package_dir = =. +# add your package requirements here +# the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +install_requires = + numpy + napari[all]>=0.4.14 + QtPy + opencv-python>=4.5.5 + dask-image>=0.6.0 + scikit-image>=0.19.2 + matplotlib>=3.4.1 + tifffile>=2022.2.9 + imageio-ffmpeg>=0.4.5 + torch>=1.11 + monai[nibabel,einops]>=0.9.0 + itk + tqdm + nibabel + pyclesperanto-prototype + scikit-image + pillow + tqdm + matplotlib + vispy>=0.9.6 + [options.packages.find] where = . From d131cdf378caf4d355309145a852275fa0a27b15 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 197/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../dev_scripts/artefact_labeling.py | 12 +- .../dev_scripts/correct_labels.py | 28 +- .../dev_scripts/evaluate_labels.py | 303 ++++++++++++++++-- notebooks/assess_instance.ipynb | 281 +++++++++++++--- setup.cfg | 2 +- 5 files changed, 524 insertions(+), 102 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 3f95e1a8..102a7d35 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,10 +1,9 @@ import os import napari -import numpy as np -import scipy.ndimage as ndimage -from skimage.filters import threshold_otsu -from tifffile import imread, imwrite + +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed @@ -105,7 +104,6 @@ def make_labels( image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( np.max(image_contrasted) - np.min(image_contrasted) ) - image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) @@ -126,7 +124,9 @@ def make_labels( ) -def select_image_by_labels(image, labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 168990e1..c888378c 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -10,12 +10,13 @@ from napari.qt.threading import thread_worker from tifffile import imread, imwrite from tqdm import tqdm - -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import threading # import sys # sys.path.append(str(Path(__file__) / "../../")) + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -138,13 +139,7 @@ def ask_labels(unique_artefact, test=False): def relabel( - image_path, - label_path, - go_fast=False, - check_for_unicity=True, - delay=0.3, - viewer=None, - test=False, + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -175,10 +170,9 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - if not test: - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -205,11 +199,7 @@ def relabel( artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - viewer.add_image(image, name="image") + viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") if not test: diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index bd2f0768..f75ed6bd 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,19 +1,269 @@ import napari import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict +import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, + labels, model_labels, do_print=False, visualize=False ): """Evaluate the model performance. Parameters @@ -95,36 +345,35 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") + if do_print: + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) log.info( - f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + "Mean true positive ratio of the model: ", ) + log.info(mean_true_positive_ratio_model) log.info( - f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" + "Mean ratio of the neurons pixels correctly labelled: ", ) + log.info(mean_ratio_pixel_found) log.info( - f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", ) + log.info(mean_ratio_pixel_found_fused) log.info( - f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" + "Mean true positive ratio of the model for fused neurons: ", ) + log.info(mean_true_positive_ratio_model_fused) log.info( - f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -141,7 +390,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0, ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..fa22c7b7 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -19,7 +19,6 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", - " to_semantic,\n", ")" ] }, @@ -45,14 +44,12 @@ }, "outputs": [ { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] } ], "source": [ @@ -69,7 +66,9 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")" + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" ] }, { @@ -198,24 +197,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -250,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -303,14 +299,138 @@ " 1.0)" ] }, - "execution_count": 10, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected, name=\"connected\")\n", + "connected.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=2, rem_seed_thres=1\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -318,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -332,24 +452,24 @@ "(25, 64, 64)" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 2)\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -363,7 +483,7 @@ "dtype('int64')" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -374,35 +494,101 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels_resized, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { @@ -414,24 +600,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-22 14:47:30,755 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" ] }, { @@ -444,18 +627,18 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" ] }, - "execution_count": 15, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -466,7 +649,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/setup.cfg b/setup.cfg index ede7724d..78cc98ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=0.9.0 + monai[nibabel,einops]>=1.0.1 itk tqdm nibabel From 85f50ab5a2c38a6cd1ef31bfed1264986c2367ea Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 198/577] Enfore pre-commit style --- .gitignore | 4 - .../_tests/test_plugin_inference.py | 1 - .../code_models/model_instance_seg.py | 9 +- .../code_plugins/plugin_model_inference.py | 12 +- .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 1 + .../dev_scripts/artefact_labeling.py | 38 +- .../dev_scripts/correct_labels.py | 12 +- .../dev_scripts/evaluate_labels.py | 471 +----------------- notebooks/assess_instance.ipynb | 50 +- 10 files changed, 88 insertions(+), 514 deletions(-) diff --git a/.gitignore b/.gitignore index df43b4fa..e86beea4 100644 --- a/.gitignore +++ b/.gitignore @@ -106,7 +106,3 @@ notebooks/full_plot.html *.png *.prof -#include test data -!napari_cellseg3d/_tests/res/test.tif -!napari_cellseg3d/_tests/res/test.png -!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..e15958e6 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,7 +7,6 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 60f8bbda..40a07893 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -7,6 +7,7 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed +from tifffile import imread # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -550,16 +551,14 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug( - f"Caught runtime error {e}, most likely during testing" - ) + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets: + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 22867343..302a52c9 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,14 +9,10 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceMethod, - InstanceWidgets, -) -from napari_cellseg3d.code_models.model_workers import ( - InferenceResult, - InferenceWorker, -) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 5463a4ff..fdcad6d3 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,9 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 737b53aa..6df82043 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -9,6 +9,7 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 102a7d35..04e288d8 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -412,22 +412,22 @@ def create_artefact_labels_from_folder( ) -# if __name__ == "__main__": -# repo_path = Path(__file__).resolve().parents[1] -# print(f"REPO PATH : {repo_path}") -# paths = [ -# "dataset_clean/cropped_visual/train", -# "dataset_clean/cropped_visual/val", -# "dataset_clean/somatomotor", -# "dataset_clean/visual_tif", -# ] -# for data_path in paths: -# path = str(repo_path / data_path) -# print(path) -# create_artefact_labels_from_folder( -# path, -# do_visualize=False, -# threshold_artefact_brightness_percent=20, -# threshold_artefact_size_percent=1, -# contrast_power=20, -# ) +if __name__ == "__main__": + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index c888378c..77835007 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -351,9 +351,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") -# -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index f75ed6bd..3eb62764 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -9,261 +9,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -345,15 +99,21 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") log.info( "Mean true positive ratio of the model: ", ) @@ -390,7 +150,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -723,193 +483,6 @@ def save_as_csv(results, path): # # return label_report_list, new_labels, map_fused_neurons -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index fa22c7b7..b2382c31 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,12 +44,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -197,21 +199,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -600,21 +605,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -627,15 +635,15 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 13, From 60f1249f5f33a10e7fb25b75cde44050be21ab3e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 199/577] Updated project files --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9c5adda7..c94c4ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ dev = [ "ruff", "tuna", "pre-commit", + ] docs = [ "sphinx", @@ -115,3 +116,4 @@ test = [ "tox", "twine", ] + From c9ce334aa957bd7a9fbaf4168ed67644230e38f3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 200/577] Removing dask-image --- .gitignore | 2 ++ napari_cellseg3d/dev_scripts/convert.py | 3 ++- napari_cellseg3d/utils.py | 13 ++++++------- notebooks/full_plot.ipynb | 1 - setup.cfg | 1 + 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index e86beea4..f8547d92 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,5 @@ notebooks/full_plot.html *.png *.prof + +*.prof diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 641de627..479a07dd 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,7 +2,8 @@ import os import numpy as np -from tifffile import imread, imwrite +from tifffile import imread +from tifffile import imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index a52c3de9..ecb6a199 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -4,6 +4,8 @@ from pathlib import Path import numpy as np + +# from dask import delayed from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -455,13 +457,10 @@ def load_images( raise ValueError("If loading as a folder, filetype must be specified") if as_folder: - try: - images_original = tfl_imread(filename_pattern_original) - except ValueError: - LOGGER.error( - "Loading a stack this way is no longer supported. Use napari to load a stack." - ) - + raise NotImplementedError( + "Loading as folder not implemented yet. Use napari to load as folder" + ) + # images_original = dask_imread(filename_pattern_original) else: images_original = tfl_imread( filename_pattern_original diff --git a/notebooks/full_plot.ipynb b/notebooks/full_plot.ipynb index 5c640e1b..87f973f9 100644 --- a/notebooks/full_plot.ipynb +++ b/notebooks/full_plot.ipynb @@ -10,7 +10,6 @@ "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", - "from PIL import Image\n", "from tifffile import imread" ] }, diff --git a/setup.cfg b/setup.cfg index 78cc98ce..6111ed7e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,7 @@ package_dir = # add your package requirements here # the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +# FIXME remove dask install_requires = numpy napari[all]>=0.4.14 From adc351b256ba727dcef1b97df2a2ed714abf9627 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 201/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f9fe2853..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - id: check-yaml - - id: check-added-large-files - - id: check-toml -# - repo: https://github.com/pycqa/isort -# rev: 5.12.0 -# hooks: -# - id: isort -# args: ["--profile", "black", --line-length=79] + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From 77d140141d2d5b2a8db856fbc9a13cb85144db6a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 202/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 187 ++++++++++-------- napari_cellseg3d/code_models/model_workers.py | 21 +- .../code_plugins/plugin_convert.py | 35 ++-- .../code_plugins/plugin_model_inference.py | 14 +- napari_cellseg3d/config.py | 17 +- napari_cellseg3d/interface.py | 10 +- requirements.txt | 4 +- 8 files changed, 149 insertions(+), 140 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e15958e6..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,6 +7,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 40a07893..9e6877da 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,18 +4,21 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label, regionprops +from skimage.measure import label +from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from tifffile import imread +from skimage.filters import thresholding +from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between +from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area @@ -80,6 +83,42 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -120,32 +159,27 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - # remove_small_size: float, + remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant - Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels Returns: Instance segmentation labels from Voronoi-Otsu method - """ - # remove_small_size (float): remove all objects smaller than the specified size in pixels - # semantic = np.squeeze(volume) - logger.debug( - f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" - ) + semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( - volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return np.array(instance) + return instance def binary_connected( @@ -381,16 +415,13 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod): - """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - - def __init__(self, widget_parent=None): +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=WATERSHED, + name="Watershed", function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -420,23 +451,20 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].slider_value, - self.sliders[1].slider_value, + self.sliders[0].value(), + self.sliders[1].value(), self.counters[0].value(), self.counters[1].value(), ) -class ConnectedComponents(InstanceMethod): - """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - - def __init__(self, widget_parent=None): +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=CONNECTED_COMP, + name="Connected Components", function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -454,56 +482,44 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].slider_value, self.counters[0].value() + image, self.sliders[0].value(), self.counters[0].value() ) -class VoronoiOtsu(InstanceMethod): - """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - - def __init__(self, widget_parent=None): +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=VORONOI_OTSU, + name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=2, - widget_parent=widget_parent, + num_counters=3, ) - self.counters[0].label.setText("Spot sigma") # closeness + self.counters[0].label.setText("Spot sigma") self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[1].label.setText("Outline sigma") self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" self.counters[1].setMaximum(100) self.counters[1].setValue(2) - # self.counters[2].label.setText("Small object removal") - # self.counters[2].tooltips = ( - # "Volume/size threshold for small object removal." - # "\nAll objects with a volume/size below this value will be removed." - # ) - # self.counters[2].setValue(30) + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) def run_method(self, image): - ################ - # For debugging - # import napari - # view = napari.Viewer() - # view.add_image(image) - # napari.run() - ################ - return self.function( image, self.counters[0].value(), self.counters[1].value(), - # self.counters[2].value(), + self.counters[2].value(), ) @@ -518,70 +534,67 @@ def __init__(self, parent=None): Args: parent: parent widget - """ super().__init__(parent) + self.method_choice = ui.DropdownMenu( - list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) + INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = {} - """Contains the instance of the method, with its name as key""" + self.methods = [] self.instance_widgets = {} - """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() def _build(self): + group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - try: - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - method_class = method(widget_parent=self.parent()) - self.methods[name] = method_class - self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets ? - if len(method_class.sliders) > 0: - for slider in method_class.sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method_class.counters) > 0: - for counter in method_class.counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): - if name != self.method_choice.currentText(): - for widget in self.instance_widgets[name]: + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: widget.set_visibility(False) - else: - for widget in self.instance_widgets[name]: - widget.set_visibility(True) def run_method(self, volume): """ Calls instance function with chosen parameters - Args: volume: image data to run method on Returns: processed image from self._method - """ - method = self.methods[self.method_choice.currentText()] + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() return method.run_method(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { - VORONOI_OTSU: VoronoiOtsu, - WATERSHED: Watershed, - CONNECTED_COMP: ConnectedComponents, + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, } diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 30d37bbd..14449854 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -49,12 +49,8 @@ from tqdm import tqdm # local -from napari_cellseg3d import config, utils -from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_instance_seg import ( - ImageStats, - volume_stats, -) +from napari_cellseg3d.code_models.model_instance_seg import ImageStats +from napari_cellseg3d.code_models.model_instance_seg import volume_stats logger = utils.LOGGER @@ -448,10 +444,11 @@ def model_output( ): inputs = inputs.to("cpu") - # def model_output(inputs): - # return post_process_transforms( - # self.config.model_info.get_model().get_output(model, inputs) - # ) + model_output = lambda inputs: post_process_transforms( + self.config.model_info.get_model().get_output( + model, inputs + ) # TODO(cyril) refactor those functions + ) def model_output(inputs): return post_process_transforms( @@ -600,8 +597,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c8370c1..6c0bc936 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -8,12 +8,10 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceWidgets, - clear_small_objects, - threshold, - to_semantic, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import threshold +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -345,19 +343,18 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 302a52c9..f9cac5f3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -13,6 +13,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -552,12 +557,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6df82043..a9e3b44f 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,9 +7,6 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -17,6 +14,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -114,17 +117,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 276f9214..8f4f2cdd 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -6,6 +6,10 @@ import napari # Qt +from qtpy import QtCore +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt + # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore from qtpy.QtCore import QObject, Qt, QUrl @@ -1046,11 +1050,11 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): # FIXME check divide_factor + def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.01) - else: self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): diff --git a/requirements.txt b/requirements.txt index 3189e9c4..9c7126eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pre-commit -pyclesperanto-prototype>=0.22.0 -pysqlite3 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 93551b71610d20ed69ec358d18eadd31aee85bd3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 203/577] isort --- .../_tests/test_weight_download.py | 7 ++- .../code_models/model_instance_seg.py | 4 +- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 ++- napari_cellseg3d/config.py | 11 ++-- napari_cellseg3d/interface.py | 53 +++++++++---------- 6 files changed, 40 insertions(+), 45 deletions(-) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index d8886a56..bffe422b 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,7 +1,6 @@ -from napari_cellseg3d.code_models.model_workers import ( - WEIGHTS_DIR, - WeightsDownloader, -) +from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import WeightsDownloader + # DISABLED, causes GitHub actions to freeze diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 9e6877da..376cf56f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,11 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget +from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.filters import thresholding from skimage.transform import resize # from skimage.measure import mesh_surface_area @@ -17,8 +17,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c0bc936..e4d7480b 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -9,9 +9,9 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index f9cac5f3..1da36989 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,15 +9,13 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index a9e3b44f..957946da 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,6 +7,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -14,12 +19,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 8f4f2cdd..d3cd4e84 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,35 +7,34 @@ # Qt from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt # from qtpy.QtCore import QtWarningMsg -from qtpy import QtCore -from qtpy.QtCore import QObject, Qt, QUrl -from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor -from qtpy.QtWidgets import ( - QCheckBox, - QComboBox, - QDoubleSpinBox, - QFileDialog, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QLineEdit, - QMenu, - QPushButton, - QRadioButton, - QScrollArea, - QSizePolicy, - QSlider, - QSpinBox, - QTextEdit, - QVBoxLayout, - QWidget, -) +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt +from qtpy.QtCore import QUrl +from qtpy.QtGui import QCursor +from qtpy.QtGui import QDesktopServices +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import QCheckBox +from qtpy.QtWidgets import QComboBox +from qtpy.QtWidgets import QDoubleSpinBox +from qtpy.QtWidgets import QFileDialog +from qtpy.QtWidgets import QGridLayout +from qtpy.QtWidgets import QGroupBox +from qtpy.QtWidgets import QHBoxLayout +from qtpy.QtWidgets import QLabel +from qtpy.QtWidgets import QLayout +from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QMenu +from qtpy.QtWidgets import QPushButton +from qtpy.QtWidgets import QRadioButton +from qtpy.QtWidgets import QScrollArea +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QSlider +from qtpy.QtWidgets import QSpinBox +from qtpy.QtWidgets import QTextEdit +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # Local from napari_cellseg3d import utils From 75a8bf10e5e6a37a8aa0bfc7d0a3d5572cebc4da Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 204/577] Fix inference --- .../code_models/model_instance_seg.py | 32 +- napari_cellseg3d/code_models/model_workers.py | 8 +- .../code_plugins/plugin_model_inference.py | 11 +- napari_cellseg3d/config.py | 6 +- notebooks/assess_instance.ipynb | 670 +----------------- 5 files changed, 46 insertions(+), 681 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 376cf56f..c0d246b1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -540,8 +540,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -551,17 +553,23 @@ def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 14449854..6003b0ae 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -542,9 +542,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -597,8 +595,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 1da36989..ff173b43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,9 +555,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -725,9 +726,7 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 957946da..84ba4215 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -115,12 +115,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b2382c31..40412282 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,691 +4,47 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "tags": [] + "collapsed": true }, "outputs": [], "source": [ - "import napari\n", "import numpy as np\n", - "from pathlib import Path\n", "from tifffile import imread\n", - "\n", - "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", - "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import (\n", - " binary_connected,\n", - " binary_watershed,\n", - " voronoi_otsu,\n", - ")" + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [] - }, + "execution_count": null, "outputs": [], - "source": [ - "viewer = napari.Viewer()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "source": [], "metadata": { "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "pycharm": { + "name": "#%%\n" } - ], - "source": [ - "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", - "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", - "\n", - "prediction = imread(prediction_path)\n", - "gt_labels = imread(gt_labels_path)\n", - "\n", - "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", - "gt_labels_resized = resize(gt_labels, zoom)\n", - "\n", - "\n", - "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.5817600487210719" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from napari_cellseg3d.utils import dice_coeff\n", - "\n", - "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", - "\n", - "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n", - "125\n" - ] - } - ], - "source": [ - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)\n", - "print(np.unique(gt_labels_resized).shape[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized, thres_small=2)\n", - "viewer.add_labels(connected, name=\"connected\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(68,\n", - " 43,\n", - " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", - ")\n", - "viewer.add_labels(watershed)\n", - "eval.evaluate_model_performance(gt_labels_resized, watershed)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "\n", - "from skimage.morphology import remove_small_objects\n", - "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "dtype('int64')" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gt_labels_resized.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(voronoi, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(gt_labels_resized, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file From aa81b8170ac7106974aa9f1fc4005f5748fc3738 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 205/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 6 +- .../dev_scripts/artefact_labeling.py | 94 ++-- .../dev_scripts/correct_labels.py | 99 ++--- .../dev_scripts/evaluate_labels.py | 409 ++++-------------- notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 5 files changed, 551 insertions(+), 458 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c0d246b1..cd101b35 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -165,14 +165,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels Returns: Instance segmentation labels from Voronoi-Otsu method + """ semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( @@ -534,6 +535,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) @@ -590,10 +592,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 04e288d8..875ca9b6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,15 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path +import scipy.ndimage as ndimage import os - import napari - # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - -# import sys -# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from skimage.filters import threshold_otsu """ New code by Yves Paychere @@ -43,9 +44,7 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append( - np.array([i, unique[np.argmax(counts)]]) - ) + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -62,7 +61,7 @@ def map_labels(labels, artefacts): def make_labels( - image, + path_image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -74,8 +73,8 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - image : str - image array + path_image : str + Path to image. path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -94,25 +93,21 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(image) + image = imread(path_image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( - np.max(image_contrasted) - np.min(image_contrasted) - ) + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size( - labels, min_size=threshold_size, is_labeled=True - ) + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -124,29 +119,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - image : np.array - image. - labels : np.array - labels. + path_image : str + Path to image. + path_labels : str + Path to labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - # image = imread(image) - # labels = imread(labels) - + image = imread(path_image) + labels = imread(path_labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smallest cube that contains all the non-zero pixels of a 3d image +# select the smalles cube that contains all the none zero pixel of an 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -164,15 +156,16 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(image, path_image_out): +def crop_image_path(path_image, path_image_out): """Crop image. Parameters ---------- - image : np.array - image + path_image : str + Path to image. path_image_out : str Path of the output image. """ + image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -220,9 +213,7 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile( - image[neurons], threshold_artefact_brightness_percent - ) + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -253,9 +244,7 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile( - sizes, threshold_artefact_size_percent - ) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -305,8 +294,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image, - labels, + image_path, + labels_path, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -315,10 +304,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image : np.array - image for artefact detection. - labels : np.array - label image array with each neurons labelled as a different int value. + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -328,6 +317,9 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ + image = imread(image_path) + labels = imread(labels_path) + artefacts = make_artefact_labels( image, labels, @@ -347,12 +339,11 @@ def visualize_images(paths): Parameters ---------- paths : list - List of images to visualize. + List of paths to images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - image = imread(path) - viewer.add_image(image) + viewer.add_image(imread(path), name=os.path.basename(path)) # wait for the user to close the viewer napari.run() @@ -379,12 +370,8 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [ - f for f in os.listdir(path + "/labels") if f.endswith(".tif") - ] - path_images = [ - f for f in os.listdir(path + "/volumes") if f.endswith(".tif") - ] + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] # sort the list path_labels.sort() path_images.sort() @@ -413,6 +400,7 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": + repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 77835007..f94327e2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,22 +1,19 @@ -import threading -import time -import warnings -from functools import partial -from pathlib import Path - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings from napari.qt.threading import thread_worker -from tifffile import imread, imwrite from tqdm import tqdm import threading - # import sys # sys.path.append(str(Path(__file__) / "../../")) +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -36,9 +33,7 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm( - range(len(unique_label)), desc="relabeling", ncols=100 - ): + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): i = unique_label[i_label] if i == 0: continue @@ -86,16 +81,13 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact, test=False): +def ask_labels(unique_artefact): global returns returns = [] - if not test: - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] - else: - i_labels_to_add_tmp = [0] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] if i_labels_to_add_tmp == [0]: print("no label added") @@ -138,9 +130,7 @@ def ask_labels(unique_artefact, test=False): print("close the napari window to continue") -def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 -): +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -154,8 +144,6 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 - viewer : napari.Viewer, optional - the napari viewer, by default None """ global returns @@ -170,9 +158,7 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -192,45 +178,30 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread( - target=partial(ask_labels, test=test), args=(unique_artefact,) - ) + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add), 0, artefact - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - if not test: - napari.run() + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add_tmp), artefact, 0 - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) print("these labels will be added") - if test: - viewer.close() - viewer = napari.view_image(image) if viewer is None else viewer - if not test: - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") - if test: - revert = "n" - viewer.close() + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) - if test: - break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -287,16 +258,12 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget( - old_label, new_label, map_labels_existing, delay=0.5 -): +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect( - lambda arg: modify_viewer(old_label, new_label, arg) - ) + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -313,12 +280,8 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] - ) - new_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] - ) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -327,9 +290,7 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget( - old_label, new_label, map_labels_existing, delay=delay - ) + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) napari.run() @@ -346,12 +307,12 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, - str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) ) if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3eb62764..857bcd19 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,24 +1,74 @@ -import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, -): +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): """Evaluate the model performance. Parameters ---------- @@ -26,10 +76,8 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - print_details : bool + do_print : bool If True, print the results. - visualize : bool - If True, visualize the results. Returns ------- neuron_found : float @@ -53,7 +101,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels, threshold_correct + labels, model_labels ) # calculate the number of neurons individually found @@ -71,9 +119,7 @@ def evaluate_model_performance( artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean( - [i[3] for i in map_labels_existing] - ) + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -82,9 +128,7 @@ def evaluate_model_performance( if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean( - [i[2] for i in map_fused_neurons] - ) + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -99,42 +143,27 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") - log.info( - "Mean true positive ratio of the model: ", - ) - log.info(mean_true_positive_ratio_model) - log.info( + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, ) - log.info(mean_ratio_pixel_found) - log.info( + print( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, ) - log.info(mean_ratio_pixel_found_fused) - log.info( + print( "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, ) - log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact ) - log.info(mean_ratio_false_pixel_artefact) - if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -150,21 +179,15 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, - unique_labels, - 0, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where( - np.isin(labels, neurones_not_found_labels), labels, 0 - ) + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), - model_labels, - 0, + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -192,81 +215,6 @@ def evaluate_model_performance( ) -def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > threshold_correct: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > threshold_correct: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels - - def save_as_csv(results, path): """ Save the results as a csv file @@ -278,7 +226,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - log.debug(np.array(results).shape) + print(np.array(results).shape) df = pd.DataFrame( [results], columns=[ @@ -296,193 +244,6 @@ def save_as_csv(results, path): df.to_csv(path, index=False) -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From 680a21333c4917ed338e0c5dd02f9df6e5d88ec3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 206/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 ++- .../dev_scripts/evaluate_labels.py | 282 ++++++++++++++++-- notebooks/assess_instance.ipynb | 239 +++++++++++---- 5 files changed, 494 insertions(+), 107 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cd101b35..8b7e234b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -180,7 +180,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,14 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): + +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -21,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -28,15 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -44,14 +228,19 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,16 +248,22 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -78,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -119,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -128,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -144,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -179,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -226,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..86ef4e29 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -111,26 +106,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,112 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -178,7 +172,8 @@ ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected, name=\"connected\")\n", + "connected.dtype" ] }, { @@ -191,26 +186,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" ] }, { @@ -253,26 +247,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" ] }, { @@ -302,7 +295,9 @@ } ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] @@ -330,6 +325,10 @@ ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] @@ -337,6 +336,33 @@ { "cell_type": "code", "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -346,28 +372,94 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { @@ -375,12 +467,51 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + "\n" ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -389,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { From 24ab925b714e5ea9e0cc037d22197f7915d7743c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 207/577] Added pre-commit hooks --- .pre-commit-config.yaml | 44 +++++++++++++++++++++++++++++------------ requirements.txt | 4 +++- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..802dfe20 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,44 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v4.0.1 +# hooks: # - id: check-docstring-first - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - id: end-of-file-fixer +# - id: trailing-whitespace +# - repo: https://github.com/asottile/setup-cfg-fmt +# rev: v1.20.0 +# hooks: +# - id: setup-cfg-fmt +# - repo: https://github.com/PyCQA/flake8 +# rev: 4.0.1 +# hooks: +# - id: flake8 +# additional_dependencies: [flake8-typing-imports>=1.9.0] +# - repo: https://github.com/myint/autoflake +# rev: v1.4 +# hooks: +# - id: autoflake +# args: ["--in-place", "--remove-all-unused-imports"] +# - repo: https://github.com/PyCQA/isort +# rev: 5.10.1 +# hooks: +# - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.262' + rev: 'v0.0.257' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=79] + args: [--line-length=88] +# - repo: https://github.com/asottile/pyupgrade +# rev: v2.29.1 +# hooks: +# - id: pyupgrade +# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/requirements.txt b/requirements.txt index 9c7126eb..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pre-commit +pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 3b49f969c0f8ec6c969ae407532f626b7375ca13 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:40:31 +0100 Subject: [PATCH 208/577] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 802dfe20..d1e22fb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: rev: 22.3.0 hooks: - id: black - args: [--line-length=88] + args: [--line-length=79] # - repo: https://github.com/asottile/pyupgrade # rev: v2.29.1 # hooks: From 5b8972cf9bc17d1491b2f1e2173a196433c90bac Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:48:32 +0100 Subject: [PATCH 209/577] Update pyproject.toml --- pyproject.toml | 52 ++------------------------------------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c94c4ee4..83aa1ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,57 +32,9 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" -[tool.setuptools] -include-package-data = true - -[tool.setuptools.packages.find] -where = ["."] - -[tool.setuptools.package-data] -"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] - [tool.ruff] -select = [ - "E", "F", "W", - "A", - "B", - "G", - "I", - "PT", - "PTH", - "RET", - "SIM", - "TCH", - "NPY", -] -# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -# and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004", "A003"] -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", - "docs/conf.py", - "napari_cellseg3d/_tests/conftest.py", -] +# Never enforce `E501` (line length violations). +ignore = ["E501"] [tool.black] line-length = 79 From dc66d49120775a13babd004f9e9f9d0639399abc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:50:33 +0100 Subject: [PATCH 210/577] Update pyproject.toml Ruff config --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 83aa1ebb..462263e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Never enforce `E501` (line length violations). -ignore = ["E501"] +ignore = ["E501", "E741"] [tool.black] line-length = 79 From 67dbabc34bf0ad8e5cdda02257d8cd33dd3acf0e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 211/577] Enfore pre-commit style --- .gitignore | 5 +- .../code_models/model_instance_seg.py | 19 +- napari_cellseg3d/code_models/model_workers.py | 9 +- .../code_plugins/plugin_convert.py | 5 +- .../code_plugins/plugin_model_inference.py | 3 - napari_cellseg3d/config.py | 7 +- .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 471 ++++++++---------- notebooks/assess_instance.ipynb | 158 +++--- 10 files changed, 313 insertions(+), 366 deletions(-) diff --git a/.gitignore b/.gitignore index f8547d92..427603f1 100644 --- a/.gitignore +++ b/.gitignore @@ -104,7 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -*.prof - - -*.prof +notebooks/instance_test.ipynb diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 8b7e234b..45a20b3d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,12 +4,10 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -570,23 +568,16 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) def run_method(self, volume): diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 6003b0ae..5456c730 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -444,11 +444,10 @@ def model_output( ): inputs = inputs.to("cpu") - model_output = lambda inputs: post_process_transforms( - self.config.model_info.get_model().get_output( - model, inputs - ) # TODO(cyril) refactor those functions - ) + # def model_output(inputs): + # return post_process_transforms( + # self.config.model_info.get_model().get_output(model, inputs) + # ) def model_output(inputs): return post_process_transforms( diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index e4d7480b..3346d2b8 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,7 +4,8 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread +from tifffile import imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -143,7 +144,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + self.results_path.mkdir(exist_ok=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ff173b43..971f81bd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,9 +9,6 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 84ba4215..3ae070e2 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,10 +7,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -89,7 +86,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info("Model list :\n" + str(f"{name}\n" for name in MODEL_LIST)) + logger.info( + "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) + ) return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..3c5be52a 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -8,261 +8,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -344,15 +98,21 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") log.info( "Mean true positive ratio of the model: ", ) @@ -389,7 +149,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -460,6 +220,193 @@ def save_as_csv(results, path): df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons + # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 86ef4e29..609da8b3 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -50,12 +50,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -84,9 +86,36 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -110,28 +139,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "125\n" ] }, { @@ -162,7 +172,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -171,9 +181,8 @@ } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized, thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { @@ -190,21 +199,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -217,14 +229,14 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", " 1.0)" ] }, @@ -251,21 +263,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -278,15 +293,15 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 8, @@ -471,21 +486,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -498,15 +516,15 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 13, From 78583a68128e313e717baeb80bb264da3b6eadc1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 212/577] Update .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 427603f1..ee1bf4a0 100644 --- a/.gitignore +++ b/.gitignore @@ -104,4 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -notebooks/instance_test.ipynb + From c3677b0e902df1e51ea2ca356abf39c53cd2d027 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 213/577] Version bump --- napari_cellseg3d/__init__.py | 2 +- .../code_plugins/plugin_helper.py | 2 +- setup.cfg | 32 +++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..2c537225 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc6" +__version__ = "0.0.2rc2" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index f8ac18ef..a20a2c61 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc6'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 6111ed7e..17ef734e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,35 @@ +[metadata] +name = napari-cellseg3d +version = 0.0.2rc2 +author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis +author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu + +license = MIT +description = plugin for cell segmentation +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Development Status :: 2 - Pre-Alpha + Intended Audience :: Science/Research + Framework :: napari + Topic :: Software Development :: Testing + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Operating System :: OS Independent + License :: OSI Approved :: MIT License + Topic :: Scientific/Engineering :: Artificial Intelligence + Topic :: Scientific/Engineering :: Image Processing + Topic :: Scientific/Engineering :: Visualization + +url = https://github.com/AdaptiveMotorControlLab/CellSeg3d +project_urls = + Bug Tracker = https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues + Documentation = https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html + Source Code = https://github.com/AdaptiveMotorControlLab/CellSeg3d + [options] packages = find: include_package_data = True From 0b6999630d7cbd95885a77acc8fa4459598cffd6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:33:40 +0200 Subject: [PATCH 214/577] Revert "Version bump" This reverts commit 6e39971b39fb926084f3ed71d82e8c25f68f8b6f. --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 2c537225..6e2681e8 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc2" +__version__ = "0.0.2rc1" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index a20a2c61..a3fd8c0d 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 17ef734e..5789f74f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc2 +version = 0.0.2rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 35b1ad0ae3ace81739dbb2106461b2dda27bbb70 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 215/577] Updated project files --- pyproject.toml | 22 +++++++++++++--------- setup.cfg | 8 +++++++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 462263e7..5dec250c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,15 +9,17 @@ authors = [ requires-python = ">=3.8" dependencies = [ "numpy", - "napari>=0.4.14", + "napari[all]>=0.4.14", "QtPy", "opencv-python>=4.5.5", + "dask-image>=0.6.0", "scikit-image>=0.19.2", "matplotlib>=3.4.1", "tifffile>=2022.2.9", "imageio-ffmpeg>=0.4.5", "torch>=1.11", "monai[nibabel,einops]>=0.9.0", + "itk", "tqdm", "nibabel", "scikit-image", @@ -32,6 +34,15 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] + +[tool.setuptools.package-data] +"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] + [tool.ruff] # Never enforce `E501` (line length violations). ignore = ["E501", "E741"] @@ -44,16 +55,10 @@ profile = "black" line_length = 79 [project.optional-dependencies] -all = [ - "napari[all]>=0.4.14", -] dev = [ "isort", "black", "ruff", - "tuna", - "pre-commit", - ] docs = [ "sphinx", @@ -67,5 +72,4 @@ test = [ "coverage", "tox", "twine", -] - +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 5789f74f..2420dd1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc1 +version = 0.0.2rc6 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu @@ -65,6 +65,12 @@ install_requires = [options.packages.find] where = . +[options.package_data] +napari-cellseg3d = + res/*.png + code_models/models/pretrained/*.json + napari.yaml + [options.entry_points] napari.manifest = napari-cellseg3d = napari_cellseg3d:napari.yaml From 8b0067c016656f11f1e38ddcabd1ee071bd3d197 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 216/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 45a20b3d..5eb987f6 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -174,6 +174,9 @@ def voronoi_otsu( """ semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -450,8 +453,8 @@ def __init__(self): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -481,7 +484,7 @@ def __init__(self): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -538,7 +541,7 @@ def __init__(self, parent=None): super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -558,7 +561,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -568,8 +571,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -590,9 +595,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) From a29ff5a439fe7e4d055f08833e317368a95a10a0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 217/577] Removing dask-image --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ee1bf4a0..0ec12b01 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png +*.prof From f68241c643a49bf248dded6cf04f209c3dbca8dd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 218/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0abcf387..24f4e867 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,12 +1,9 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, - Utilities, -) +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS def test_utils_plugin(make_napari_viewer): @@ -24,9 +21,4 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) - if utils_name == "Convert to instance labels": - # to avoid issues with Voronoi-Otsu missing runtime - menu = widget.utils_widgets[i].instance_widgets.method_choice - menu.setCurrentIndex(menu.currentIndex() + 1) - widget.utils_widgets[i]._start() From e3a33d1865f0b98fade02d92c6189ecd278619d7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 219/577] Relabeling tests --- .gitignore | 6 +- .../_tests/test_labels_correction.py | 3 +- .../dev_scripts/artefact_labeling.py | 93 +++++++++---------- .../dev_scripts/correct_labels.py | 75 ++++++++++----- 4 files changed, 102 insertions(+), 75 deletions(-) diff --git a/.gitignore b/.gitignore index 0ec12b01..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -104,5 +104,9 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png - *.prof + +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..9d4e7801 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 9a344545..bf724a46 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,7 +1,5 @@ import numpy as np -from tifffile import imread -from tifffile import imwrite -from pathlib import Path +from tifffile import imwrite, imread import scipy.ndimage as ndimage import os import napari @@ -64,7 +62,7 @@ def map_labels(labels, artefacts): def make_labels( - path_image, + image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -76,7 +74,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - path_image : str + image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -96,7 +94,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(path_image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -126,28 +124,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - path_image : str - Path to image. - path_labels : str - Path to labels. + image : np.array + image. + labels : np.array + labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - image = imread(path_image) - labels = imread(path_labels) + # image = imread(image) + # labels = imread(labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smalles cube that contains all the none zero pixel of an 3d image +# select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -165,16 +161,15 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(path_image, path_image_out): +def crop_image_path(image, path_image_out): """Crop image. Parameters ---------- - path_image : str - Path to image. + image : np.array + image path_image_out : str Path of the output image. """ - image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -307,8 +302,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image_path, - labels_path, + image, + labels, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -317,10 +312,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image_path : str - Path to image file. - labels_path : str - Path to label image file with each neurons labelled as a different value. + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -330,9 +325,6 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ - image = imread(image_path) - labels = imread(labels_path) - artefacts = make_artefact_labels( image, labels, @@ -352,11 +344,12 @@ def visualize_images(paths): Parameters ---------- paths : list - List of paths to images to visualize. + List of images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - viewer.add_image(imread(path), name=os.path.basename(path)) + image = imread(path) + viewer.add_image(image) # wait for the user to close the viewer napari.run() @@ -416,22 +409,22 @@ def create_artefact_labels_from_folder( ) -if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] - print(f"REPO PATH : {repo_path}") - paths = [ - "dataset_clean/cropped_visual/train", - "dataset_clean/cropped_visual/val", - "dataset_clean/somatomotor", - "dataset_clean/visual_tif", - ] - for data_path in paths: - path = str(repo_path / data_path) - print(path) - create_artefact_labels_from_folder( - path, - do_visualize=False, - threshold_artefact_brightness_percent=20, - threshold_artefact_size_percent=1, - contrast_power=20, - ) +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index cd09754e..50f2e47a 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -4,6 +4,7 @@ import scipy.ndimage as ndimage import napari from pathlib import Path +from functools import partial import time import warnings from napari.qt.threading import thread_worker @@ -85,13 +86,16 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact): +def ask_labels(unique_artefact, test=False): global returns returns = [] - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] if i_labels_to_add_tmp == [0]: print("no label added") @@ -135,7 +139,13 @@ def ask_labels(unique_artefact): def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -150,6 +160,8 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None """ global returns @@ -164,9 +176,10 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -186,15 +199,22 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) t.start() artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - viewer = napari.view_image(image) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - napari.run() + if not test: + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones @@ -205,15 +225,26 @@ def relabel( np.isin(artefact, i_labels_to_add_tmp), artefact, 0 ) print("these labels will be added") - viewer = napari.view_image(image) - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") + if test: + viewer.close() + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) + if test: + break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -334,9 +365,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") - image_path = str(im_path / "image.tif") - gt_labels_path = str(im_path / "labels.tif") - - relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) From 23edaa3696af354dd7418543e51487c76575552e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 220/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 43 ++++++++++++----------------------------- pyproject.toml | 3 ++- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1e22fb1..da16a3b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,44 +1,25 @@ repos: -# - repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v4.0.1 -# hooks: -# - id: check-docstring-first -# - id: end-of-file-fixer -# - id: trailing-whitespace -# - repo: https://github.com/asottile/setup-cfg-fmt -# rev: v1.20.0 -# hooks: -# - id: setup-cfg-fmt -# - repo: https://github.com/PyCQA/flake8 -# rev: 4.0.1 -# hooks: -# - id: flake8 -# additional_dependencies: [flake8-typing-imports>=1.9.0] -# - repo: https://github.com/myint/autoflake -# rev: v1.4 -# hooks: -# - id: autoflake -# args: ["--in-place", "--remove-all-unused-imports"] -# - repo: https://github.com/PyCQA/isort -# rev: 5.10.1 -# hooks: -# - id: isort + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.257' + rev: 'v0.0.262' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black args: [--line-length=79] -# - repo: https://github.com/asottile/pyupgrade -# rev: v2.29.1 -# hooks: -# - id: pyupgrade -# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 5dec250c..d2a2adbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "isort", "black", "ruff", + "pre-commit", ] docs = [ "sphinx", @@ -72,4 +73,4 @@ test = [ "coverage", "tox", "twine", -] \ No newline at end of file +] From 6a18d6a02e26b78a515a17c4e2615c795ba29886 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 221/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da16a3b9..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,13 +2,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-docstring-first +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From c8140c2830714c0f5ff524746dbbf1e723d544c4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 222/577] Run full suite of pre-commit hooks --- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 10 ++++++++- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 3 +-- 6 files changed, 34 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 24f4e867..7908e8b4 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 5eb987f6..e33d1d0f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,7 +15,7 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -517,6 +517,14 @@ def __init__(self): ) def run_method(self, image): + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3c5be52a..26b45d3f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,10 +1,9 @@ +import napari import numpy as np from collections import Counter from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict -import napari from napari_cellseg3d.utils import LOGGER as log From b146f14c19ffaf3803dfa257518c32e24de79f81 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 24 Mar 2023 17:08:44 +0100 Subject: [PATCH 223/577] Model class refactor --- docs/res/guides/custom_model_template.rst | 24 -- .../_tests/test_weight_download.py | 4 +- napari_cellseg3d/code_models/model_workers.py | 367 ++++++++++-------- .../code_models/models/model_SegResNet.py | 48 ++- .../code_models/models/model_SwinUNetR.py | 36 +- .../code_models/models/model_TRAILMAP.py | 39 +- .../code_models/models/model_TRAILMAP_MS.py | 27 +- .../code_models/models/model_VNet.py | 56 +-- .../code_models/models/model_test.py | 24 +- .../code_plugins/plugin_model_inference.py | 141 ++++--- .../code_plugins/plugin_model_training.py | 4 +- .../code_plugins/plugin_review.py | 2 +- napari_cellseg3d/config.py | 65 +++- napari_cellseg3d/interface.py | 18 +- napari_cellseg3d/utils.py | 18 +- notebooks/assess_instance.ipynb | 121 +++--- requirements.txt | 6 +- setup.cfg | 2 +- 18 files changed, 568 insertions(+), 434 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index afbcd98a..9bad49b0 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -10,28 +10,4 @@ To add a custom model, you will need a **.py** file with the following structure :: - def get_net(): - return ModelClass # should return the class of the model, - # for example SegResNet or UNET - - def get_weights_file(): - return "weights_file.pth" # name of the weights file for the model, - # which should be in *napari_cellseg3d/models/pretrained* - - - def get_output(model, input): - out = model(input) # should return the model's output as [C, N, D,H,W] - # (C: channel, N, batch size, D,H,W : depth, height, width) - return out - - - def get_validation(model, val_inputs): - val_outputs = model(val_inputs) # should return the proper type for validation - # with sliding_window_inference from MONAI - return val_outputs - - - def ModelClass(x1,x2...): - # your Pytorch model here... - return results # should return as [C, N, D,H,W] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index bffe422b..1bcb40d7 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR from napari_cellseg3d.code_models.model_workers import WeightsDownloader @@ -7,6 +7,6 @@ def test_weight_download(): downloader = WeightsDownloader() downloader.download_weights("test", "test.pth") - result_path = WEIGHTS_DIR / "test.pth" + result_path = PRETRAINED_WEIGHTS_DIR / "test.pth" assert result_path.is_file() diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 5456c730..f5bb798c 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List, Optional +import typing as t import numpy as np import torch @@ -41,7 +41,10 @@ from monai.utils import set_determinism # threads -from napari.qt.threading import GeneratorWorker, WorkerBaseSignals +from napari.qt.threading import GeneratorWorker + +# from napari.qt.threading import thread_worker +from napari.qt.threading import WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -64,14 +67,16 @@ # https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ # https://napari-staging-site.github.io/guides/stable/threading.html -WEIGHTS_DIR = Path(__file__).parent.resolve() / Path("models/pretrained") -logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {WEIGHTS_DIR}") +PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( + "models/pretrained" +) +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") class WeightsDownloader: """A utility class the downloads the weights of a model when needed.""" - def __init__(self, log_widget: Optional[ui.Log] = None): + def __init__(self, log_widget: t.Optional[ui.Log] = None): """ Creates a WeightsDownloader, optionally with a log widget to display the progress. @@ -93,11 +98,11 @@ def download_weights(self, model_name: str, model_weights_filename: str): import tarfile import urllib.request - def show_progress(count, block_size, total_size): + def show_progress(_, block_size, __): # count, block_size, total_size pbar.update(block_size) logger.info("*" * 20) - pretrained_folder_path = WEIGHTS_DIR + pretrained_folder_path = PRETRAINED_WEIGHTS_DIR json_path = pretrained_folder_path / Path("pretrained_model_urls.json") check_path = pretrained_folder_path / Path(model_weights_filename) @@ -167,12 +172,17 @@ def safe_extract( class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" # TODO link ? + Separate from Worker instances as indicated `here`_ + + .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + """ # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" warn_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" + error_signal = Signal(Exception, str) + """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -203,33 +213,24 @@ def __init__( ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. - Args: - * config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * device: cuda or cpu device to use for torch - - * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance + The config contains the following attributes: + * device: cuda or cpu device to use for torch + * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance + * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom + * results_path: the path to save the results to + * filetype: the file extension to use when saving, + * transforms: a dict containing transforms to perform at various times. + * instance: a dict containing parameters regarding instance segmentation + * use_window: use window inference with specific size or whole image + * window_infer_size: size of window if use_window is True + * keep_on_cpu: keep images on CPU or no + * stats_csv: compute stats on cells and save them to a csv file + * images_filepaths: the paths to the images of the dataset + * layer: the layer to run inference on - * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom - - * results_path: the path to save the results to - - * filetype: the file extension to use when saving, - - * transforms: a dict containing transforms to perform at various times. - - * instance: a dict containing parameters regarding instance segmentation - - * use_window: use window inference with specific size or whole image - - * window_infer_size: size of window if use_window is True - - * keep_on_cpu: keep images on CPU or no - - * stats_csv: compute stats on cells and save them to a csv file - - * images_filepaths: the paths to the images of the dataset + Args: + * worker_config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * layer: the layer to run inference on Note: See :py:func:`~self.inference` """ @@ -237,6 +238,7 @@ def __init__( self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self.config = worker_config @@ -269,6 +271,21 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Raises an error in main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + + self.log_signal.emit("!" * 20) + self.log_signal.emit("Error occured") + # self.log_signal.emit(msg) + # self.log_signal.emit(str(exception)) + + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + yield exception + # self.quit() + def log_parameters(self): config = self.config @@ -398,7 +415,7 @@ def load_layer(self): ) # for anisotropy to be monai-like, i.e. zyx # FIXME rotation not always correct dims_check = volume.shape - # self.log("\nChecking dimensions...") + self.log("Checking dimensions...") pad = utils.get_padding_dim(dims_check) # logger.debug(volume.shape) @@ -449,54 +466,61 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): - return post_process_transforms( - self.config.model_info.get_model().get_output(model, inputs) - ) - - dataset_device = ( - "cpu" if self.config.keep_on_cpu else self.config.device - ) - - window_size = self.config.sliding_window_config.window_size - window_overlap = self.config.sliding_window_config.window_overlap - - # FIXME - # import sys - - # old_stdout = sys.stdout - # old_stderr = sys.stderr - - # sys.stdout = self.downloader.log_widget - # sys.stdout = self.downloader.log_widget - - outputs = sliding_window_inference( - inputs, - roi_size=[window_size, window_size, window_size], - sw_batch_size=1, # TODO add param - predictor=model_output, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) - - # sys.stdout = old_stdout - # sys.stderr = old_stderr - - out = outputs.detach().cpu() - - if aniso_transform is not None: - out = aniso_transform(out) + if self.config.keep_on_cpu: + dataset_device = "cpu" + else: + dataset_device = self.config.device - if post_process: - out = np.array(out).astype(np.float32) - out = np.squeeze(out) - return out + if self.config.sliding_window_config.is_enabled(): + window_size = self.config.sliding_window_config.window_size + window_size = [window_size, window_size, window_size] + window_overlap = self.config.sliding_window_config.window_overlap else: - return out + window_size = None + window_overlap = 0 + try: + # logger.debug(f"model : {model}") + logger.debug(f"inputs shape : {inputs.shape}") + logger.debug(f"inputs type : {inputs.dtype}") + try: + # outputs = model(inputs) + + def model_output_wrapper(inputs): + result = model(inputs) + return post_process_transforms(result) + + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) + except Exception as e: + logger.error(e, exc_info=True) + logger.debug("failed to run sliding window inference") + self.raise_error(e, "Error during sliding window inference") + logger.debug(f"Inference output shape: {outputs.shape}") + self.log("Post-processing...") + out = outputs.detach().cpu().numpy() + if aniso_transform is not None: + out = aniso_transform(out) + if post_process: + out = np.array(out).astype(np.float32) + out = np.squeeze(out) + return out + else: + return out + except Exception as e: + logger.error(e, exc_info=True) + self.raise_error(e, "Error during sliding window inference") + # sys.stdout = old_stdout + # sys.stderr = old_stderr - def create_result_dict( # FIXME replace with result class + def create_inference_result( self, semantic_labels, instance_labels, @@ -570,7 +594,10 @@ def save_image( + f"_{time}_" + self.config.filetype ) - imwrite(file_path, image) + try: + imwrite(file_path, image) + except ValueError as e: + self.raise_error(e, "Error during image saving") filename = Path(file_path).stem if from_layer: @@ -635,7 +662,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.log(f"Inference completed on image n°{i+1}") - return self.create_result_dict( + return self.create_inference_result( out, instance_labels, from_layer=False, @@ -646,9 +673,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): def stats_csv(self, instance_labels): if self.config.compute_stats: - stats = volume_stats( - instance_labels - ) # TODO test with area mesh function + stats = volume_stats(instance_labels) return stats # except ValueError as e: @@ -674,13 +699,14 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels, stats = self.get_instance_result(out, from_layer=True) - return self.create_result_dict( + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels, from_layer=True, stats=stats, ) + # @thread_worker(connect={"errored": self.raise_error}) def inference(self): """ Requires: @@ -723,35 +749,68 @@ def inference(self): try: dims = self.config.model_info.model_input_size - # self.log(f"MODEL DIMS : {dims}") + self.log(f"MODEL DIMS : {dims}") model_name = self.config.model_info.name model_class = self.config.model_info.get_model() - self.log(model_name) + self.log(f"Model name : {model_name}") weights_config = self.config.weights_config post_process_config = self.config.post_process_config - if model_name == "SegResNet": - model = model_class.get_net( - input_image_size=[ - dims, - dims, - dims, - ], # TODO FIX ! find a better way & remove model-specific code + # try: + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[128, 128, 128], + ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, ) - elif model_name == "SwinUNetR": - model = model_class.get_net( - img_size=[dims, dims, dims], - use_checkpoint=False, + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - else: - model = model_class.get_net() - model = model.to(self.config.device) + model.load_state_dict( + torch.load( + weights, + map_location=self.config.device, + ) + ) + self.log("Done") + # except Exception as e: + # self.raise_error(e, "Issue loading weights") + # except Exception as e: + # self.raise_error(e, "Issue instantiating model") + + # if model_name == "SegResNet": + # model = model_class( + # input_image_size=[ + # dims, + # dims, + # dims, + # ], + # ) + # elif model_name == "SwinUNetR": + # model = model_class( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = model_class.get_net() self.log_parameters() - model.to(self.config.device) - # load_transforms = Compose( # [ # LoadImaged(keys=["image"]), @@ -772,25 +831,6 @@ def inference(self): AsDiscrete(threshold=t), EnsureType() ) - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path - else: - self.downloader.download_weights( - model_name, - model_class.get_weights_file(), - ) - weights = str( - WEIGHTS_DIR / Path(model_class.get_weights_file()) - ) - model.load_state_dict( - torch.load( - weights, - map_location=self.config.device, - ) - ) - self.log("Done") - is_folder = self.config.images_filepaths is not None is_layer = self.config.layer is not None @@ -815,6 +855,9 @@ def inference(self): else: raise ValueError("No data has been provided. Aborting.") + if model is None: + raise ValueError("Model is None") + model.eval() with torch.no_grad(): ################################ @@ -830,9 +873,10 @@ def inference(self): input_image, model, post_process_transforms ) model.to("cpu") - + # self.quit() except Exception as e: - self.log(f"Error during inference : {e}") + logger.error(e, exc_info=True) + self.raise_error(e, "Inference failed") self.quit() finally: self.quit() @@ -842,10 +886,10 @@ def inference(self): class TrainingReport: show_plot: bool = True epoch: int = 0 - loss_values: List = None - validation_metric: List = None + loss_values: t.Dict = None # TODO(cyril) : change to dict and unpack different losses for e.g. WNet with several losses + validation_metric: t.List = None weights: np.array = None - images: List[np.array] = None + images: t.List[np.array] = None class TrainingWorker(GeneratorWorker): @@ -897,6 +941,7 @@ def __init__( self._signals = LogSignal() self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self._weight_error = False ############################################# @@ -922,6 +967,14 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Sends an error to main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + self.quit() + def log_parameters(self): self.log("-" * 20) self.log("Parameters summary :\n") @@ -1051,23 +1104,14 @@ def train(self): do_sampling = self.config.sampling - if model_name == "SegResNet": - size = self.config.sample_size if do_sampling else check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - input_image_size=utils.get_padding_dim(size), - # out_channels=1, - # dropout_prob=0.3, - ) - elif model_name == "SwinUNetR": - size = self.sample_size if do_sampling else check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - img_size=utils.get_padding_dim(size), - use_checkpoint=True, - ) + if do_sampling: + size = self.config.sample_size else: - model = model_class.get_net() # get an instance of the model + size = check + + model = model_class( # FIXME check if correct + input_img_size=utils.get_padding_dim(size), use_checkpoint=True + ) model = model.to(self.config.device) epoch_loss_values = [] @@ -1207,7 +1251,11 @@ def train(self): else: load_whole_images = Compose( [ - LoadImaged(keys=["image", "label"]), + LoadImaged( + keys=["image", "label"], + # image_only=True, + # reader=WSIReader(backend="tifffile") + ), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="PLI"), SpatialPadd( @@ -1254,9 +1302,9 @@ def train(self): if weights_config.custom: if weights_config.use_pretrained: - weights_file = model_class.get_weights_file() + weights_file = model_class.weights_file self.downloader.download_weights(model_name, weights_file) - weights = WEIGHTS_DIR / Path(weights_file) + weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) weights_config.path = weights else: weights = str(Path(weights_config.path)) @@ -1270,6 +1318,7 @@ def train(self): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") + logger.error(e, exc_info=True) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" @@ -1317,7 +1366,7 @@ def train(self): batch_data["label"].to(device), ) optimizer.zero_grad() - outputs = model_class.get_output(model, inputs) + outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") loss = self.config.loss_function(outputs, labels) loss.backward() @@ -1348,10 +1397,24 @@ def train(self): val_data["image"].to(device), val_data["label"].to(device), ) - - val_outputs = model_class.get_validation( - model, val_inputs + self.log("Performing validation...") + try: + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) + except Exception as e: + self.raise_error(e, "Error during validation") + logger.debug( + f"val_outputs shape : {val_outputs.shape}" ) + # val_outputs = model(val_inputs) pred = decollate_batch(val_outputs) @@ -1398,7 +1461,7 @@ def train(self): weights=model.state_dict(), images=checkpoint_output, ) - + self.log("Validation completed") yield train_report weights_filename = ( @@ -1431,7 +1494,7 @@ def train(self): model.to("cpu") except Exception as e: - self.log(f"Error in training : {e}") + self.raise_error(e, "Error in training") self.quit() finally: self.quit() diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py index 8856e18d..8b6e6e65 100644 --- a/napari_cellseg3d/code_models/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,21 +1,33 @@ from monai.networks.nets import SegResNetVAE -def get_net(input_image_size, out_channels=1, dropout_prob=0.3): - return SegResNetVAE( - input_image_size, out_channels=out_channels, dropout_prob=dropout_prob - ) - - -def get_weights_file(): - return "SegResNet.pth" - - -def get_output(model, input): - out = model(input)[0] - return out - - -def get_validation(model, val_inputs): - val_outputs = model(val_inputs) - return val_outputs[0] +class SegResNet_(SegResNetVAE): + use_default_training = True + weights_file = "SegResNet.pth" + + def __init__( + self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs + ): + super().__init__( + input_img_size, + out_channels=out_channels, + dropout_prob=dropout_prob, + ) + + def forward(self, x): + res = SegResNetVAE.forward(self, x) + # logger.debug(f"SegResNetVAE.forward: {res[0].shape}") + return res[0] + + def get_model_test(self, size): + return SegResNetVAE( + size, in_channels=1, out_channels=1, dropout_prob=0.3 + ) + + # def get_output(model, input): + # out = model(input)[0] + # return out + + # def get_validation(model, val_inputs): + # val_outputs = model(val_inputs) + # return val_outputs[0] diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 532aeb89..fe4d380c 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,25 +1,23 @@ -import torch from monai.networks.nets import SwinUNETR -def get_weights_file(): - return "Swin64_best_metric.pth" +class SwinUNETR_(SwinUNETR): + use_default_training = True + weights_file = "Swin64_best_metric.pth" + def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs + ) -def get_net(img_size, use_checkpoint=True): - return SwinUNETR( - img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - ) + # def get_output(self, input): + # out = self(input) + # return torch.sigmoid(out) - -def get_output(model, input): - out = model(input) - return torch.sigmoid(out) - - -def get_validation(model, val_inputs): - return model(val_inputs) + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 09de2a26..8a108e37 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,28 +2,8 @@ from torch import nn -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" - # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them - - -def get_net(): - return TRAILMAP(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - return model(val_inputs) - - class TRAILMAP(nn.Module): - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input self.conv1 = self.encoderBlock(32, 64, 3) # l1 @@ -112,3 +92,20 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) return out + + +class TRAILMAP_(TRAILMAP): + use_default_training = True + weights_file = "TRAILMAP_PyTorch.pth" # model additionally trained on Mathis/Wyss mesoSPIM data + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + # def get_output(model, input): + # out = model(input) + # + # return out + + # def get_validation(model, val_inputs): + # return model(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 0fc68d34..e3ca00a6 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,20 +1,21 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D -def get_weights_file(): - # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) - return "TRAILMAP_MS_best_metric_epoch_26.pth" - - -def get_net(): - return UNet3D(1, 1) +class TRAILMAP_MS_(UNet3D): + use_default_training = True + weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) -def get_output(model, input): - out = model(input) - - return out + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + # def get_output(self, input): + # out = self(input) -def get_validation(model, val_inputs): - return model(val_inputs) + # return out + # + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 0c854832..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -1,29 +1,33 @@ -from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -def get_net(): - return VNet() - - -def get_weights_file(): - return "VNet_40e.pth" - - -def get_output(model, input): - out = model(input) - return out - - -def get_validation(model, val_inputs): - roi_size = (64, 64, 64) - sw_batch_size = 1 - val_outputs = sliding_window_inference( - val_inputs, - roi_size, - sw_batch_size, - model, - mode="gaussian", - overlap=0.7, - ) - return val_outputs +class VNet_(VNet): + use_default_training = True + weights_file = "VNet_40e.pth" + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError: + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) + + # def get_output(self, input): + # out = self(input) + # return out + + # def get_validation(self, val_inputs): # FIXME standardize + # roi_size = (64, 64, 64) + # sw_batch_size = 1 + # val_outputs = sliding_window_inference( + # val_inputs, + # roi_size, + # sw_batch_size, + # self, + # # mode="gaussian", + # # overlap=0.7, + # ) + # return val_outputs diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 5871c4a7..1ccac3da 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -2,26 +2,22 @@ from torch import nn -def get_weights_file(): - return "test.pth" - - class TestModel(nn.Module): - def __init__(self): + use_default_training = True + weights_file = "test.pth" + + def __init__(self, **kwargs): super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - def get_net(self): - return self - - def get_output(self, _, input): - return input + # def get_output(self, _, input): + # return input - def get_validation(self, val_inputs): - return val_inputs + # def get_validation(self, val_inputs): + # return val_inputs # if __name__ == "__main__": @@ -29,8 +25,8 @@ def get_validation(self, val_inputs): # model = TestModel() # model.train() # model.zero_grad() -# from napari_cellseg3d.config import WEIGHTS_DIR +# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR # torch.save( # model.state_dict(), -# WEIGHTS_DIR + f"/{get_weights_file()}" +# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" # ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 971f81bd..ab61b590 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -159,6 +159,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) + self.window_size_choice.setCurrentIndex(3) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -601,10 +602,13 @@ def start(self): self.worker.set_download_log(self.log) self.worker.started.connect(self.on_start) + self.worker.log_signal.connect(self.log.print_and_log) self.worker.warn_signal.connect(self.log.warn) + self.worker.error_signal.connect(self.log.error) + self.worker.yielded.connect(partial(self.on_yield)) # - self.worker.errored.connect(partial(self.on_yield)) + self.worker.errored.connect(partial(self.on_error)) self.worker.finished.connect(self.on_finish) if self.get_device(show=False) == "cuda": @@ -641,15 +645,18 @@ def on_start(self): self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") - def on_error(self): - """Catches errors and tries to clean up. TODO : upgrade""" + def on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" * 20) self.log.print_and_log("Worker errored...") - self.log.print_and_log("Trying to clean up...") + self.log.error(error) + # self.log.print_and_log("Trying to clean up...") + self.worker.quit() self.btn_start.setText("Start") self.btn_close.setVisible(True) - self.worker = None self.worker_config = None + self.worker = None self.empty_cuda_cache() def on_finish(self): @@ -672,83 +679,91 @@ def on_yield(self, result: InferenceResult): data (dict): dict yielded by :py:func:`~inference()`, contains : "image_id" : index of the returned image, "original" : original volume used for inference, "result" : inference result widget (QWidget): widget for accessing attributes """ + + if isinstance(result, Exception): + self.on_error(result) + # raise result # viewer, progress, show_res, show_res_number, zoon, show_original # check that viewer checkbox is on and that max number of displays has not been reached. # widget.log.print_and_log(result) + try: + image_id = result.image_id + model_name = result.model_name + if self.worker_config.images_filepaths is not None: + total = len(self.worker_config.images_filepaths) + else: + total = 1 - image_id = result.image_id - model_name = result.model_name - if self.worker_config.images_filepaths is not None: - total = len(self.worker_config.images_filepaths) - else: - total = 1 + viewer = self._viewer - viewer = self._viewer + pbar_value = image_id // total + if pbar_value == 0: + pbar_value = 1 - pbar_value = image_id // total - if pbar_value == 0: - pbar_value = 1 + self.progress.setValue(100 * pbar_value) - self.progress.setValue(100 * pbar_value) + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): + zoom = self.worker_config.post_process_config.zoom.zoom_values - if ( - self.config.show_results - and image_id <= self.config.show_results_count - ): - zoom = self.worker_config.post_process_config.zoom.zoom_values + viewer.dims.ndisplay = 3 + viewer.scale_bar.visible = True - viewer.dims.ndisplay = 3 - viewer.scale_bar.visible = True + if self.config.show_original and result.original is not None: + viewer.add_image( + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, + ) + + out_colormap = "twilight" + if self.worker_config.post_process_config.thresholding.enabled: + out_colormap = "turbo" - if self.config.show_original and result.original is not None: viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, ) - out_colormap = "twilight" - if self.worker_config.post_process_config.thresholding.enabled: - out_colormap = "turbo" - - viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, - ) - - if result.instance_labels is not None: - labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + if result.instance_labels is not None: + labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + stats = result.stats - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if self.worker_config.compute_stats and stats is not None: + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + self.log.print_and_log( + f"Number of instances : {stats.number_objects}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) + except Exception as e: + self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index cf8e4b85..4f2f7cdf 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -985,7 +985,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1153,7 +1153,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e) + logger.error(e, exc_info=True) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 7ed6c549..e3e05f6c 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -401,7 +401,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3ae070e2..3d1d6d9e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -10,12 +10,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet -from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.code_models.models import ( - model_TRAILMAP_MS as TRAILMAP_MS, -) -from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ +from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ +from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ +from napari_cellseg3d.code_models.models.model_VNet import VNet_ + from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -24,16 +23,15 @@ # TODO(cyril) add JSON load/save MODEL_LIST = { - "SegResNet": SegResNet, - "VNet": VNet, + "SegResNet": SegResNet_, + "VNet": VNet_, # "TRAILMAP": TRAILMAP, - "TRAILMAP_MS": TRAILMAP_MS, - "SwinUNetR": SwinUNetR, + "TRAILMAP_MS": TRAILMAP_MS_, + "SwinUNetR": SwinUNETR_, # "test" : DO NOT USE, reserved for testing } - -WEIGHTS_DIR = str( +PRETRAINED_WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") ) @@ -69,8 +67,11 @@ class ReviewSession: @dataclass class ModelInfo: - """Dataclass recording model info : - - name (str): name of the model""" + """Dataclass recording model info + Args: + name (str): name of the model + model_input_size (Optional[List[int]]): input size of the model + """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None @@ -94,7 +95,7 @@ def get_model_name_list(): @dataclass class WeightsInfo: - path: str = WEIGHTS_DIR + path: str = PRETRAINED_WEIGHTS_DIR custom: bool = False use_pretrained: Optional[bool] = False @@ -121,6 +122,14 @@ class InstanceSegConfig: @dataclass class PostProcessConfig: + """Class to record params for post processing + + Args: + zoom (Zoom): zoom config + thresholding (Thresholding): thresholding config + instance (InstanceSegConfig): instance segmentation config + """ + zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() instance: InstanceSegConfig = InstanceSegConfig() @@ -141,7 +150,15 @@ def is_enabled(self): @dataclass class InfererConfig: - """Class to record params for Inferer plugin""" + """Class to record params for Inferer plugin + + Args: + model_info (ModelInfo): model info + show_results (bool): show results in napari + show_results_count (int): number of results to show + show_original (bool): show original image in napari + anisotropy_resolution (List[int]): anisotropy resolution + """ model_info: ModelInfo = None show_results: bool = False @@ -152,7 +169,21 @@ class InfererConfig: @dataclass class InferenceWorkerConfig: - """Class to record configuration for Inference job""" + """Class to record configuration for Inference job + + Args: + device (str): device to use for inference + model_info (ModelInfo): model info + weights_config (WeightsInfo): weights info + results_path (str): path to save results + filetype (str): filetype to save results + keep_on_cpu (bool): keep results on cpu + compute_stats (bool): compute stats + post_process_config (PostProcessConfig): post processing config + sliding_window_config (SlidingWindowConfig): sliding window config + images_filepaths (str): path to images to infer + layer (napari.layers.Layer): napari layer to infer on + """ device: str = "cpu" model_info: ModelInfo = ModelInfo() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d3cd4e84..57b3b0bd 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -295,6 +295,22 @@ def warn(self, warning): finally: self.lock.release() + def error(self, error, msg=None): + """Show exception and message from another thread""" + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Excepetion caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() + ############## # UI elements @@ -1199,7 +1215,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") filenames = QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + widget, "Open directory", default_path + "/.." ) return filenames diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index ecb6a199..5683c541 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,10 +2,8 @@ import warnings from datetime import datetime from pathlib import Path - import numpy as np - -# from dask import delayed +from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -38,6 +36,18 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] +# class TiffFileReader(ImageReader): +# def __init__(self): +# super().__init__() +# +# def verify_suffix(self, filename): +# if filename == "tif": +# return True +# def read(self, data, **kwargs): +# return tfl_imread(data) +# +# def get_data(self, data): +# return data, {} def normalize_x(image): """Normalizes the values of an image array to be between [-1;1] rather than [0;255] @@ -122,8 +132,6 @@ def dice_coeff(y_true, y_pred): def resize(image, zoom_factors): - from monai.transforms import Zoom - isotropic_image = Zoom( zoom_factors, keep_size=False, diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 609da8b3..169775f5 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -49,10 +49,20 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -62,14 +72,15 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", + "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = resize(prediction, zoom)\n", + "prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -92,7 +103,7 @@ { "data": { "text/plain": [ - "0.5817600487210719" + "0.7538125057831502" ] }, "execution_count": 4, @@ -103,9 +114,15 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", + "semantic_gt = to_semantic(gt_labels_resized.copy())\n", + "semantic_pred = to_semantic(prediction_resized.copy())\n", + "\n", + "viewer.add_image(semantic_gt, colormap='bop blue')\n", + "viewer.add_image(semantic_pred, colormap='red')\n", + "\n", "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", + " semantic_gt,\n", + " prediction_resized\n", ")" ] }, @@ -172,7 +189,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -199,24 +216,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-24 14:23:13,590 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", + "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", + "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" ] }, { @@ -229,15 +246,15 @@ { "data": { "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" + "(63,\n", + " 45,\n", + " 16,\n", + " 16,\n", + " 0.819027731148306,\n", + " 0.8401649108992161,\n", + " 0.83609908334452,\n", + " 0.8066092803671974,\n", + " 0.98)" ] }, "execution_count": 7, @@ -263,24 +280,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" + "2023-03-24 14:23:13,732 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", + "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", + "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -293,15 +310,15 @@ { "data": { "text/plain": [ - "(68,\n", - " 43,\n", + "(77,\n", + " 34,\n", " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" + " 9,\n", + " 0.728461197681457,\n", + " 0.8885669859686413,\n", + " 0.8950588507577087,\n", + " 0.7472814623489069,\n", + " 0.878614359974009)" ] }, "execution_count": 8, @@ -339,7 +356,7 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", @@ -486,24 +503,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-24 14:23:14,241 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", + "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", + "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" ] }, { @@ -516,15 +533,15 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(102,\n", + " 8,\n", + " 14,\n", + " 16,\n", + " 0.708505702558253,\n", + " 0.8832633585884945,\n", + " 0.9759871495093808,\n", + " 0.6670483272595948,\n", + " 0.8653680990771155)" ] }, "execution_count": 13, diff --git a/requirements.txt b/requirements.txt index 3189e9c4..3ca0e56d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ black coverage +imageio-ffmpeg>=0.4.5 isort itk pytest @@ -15,13 +16,12 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 +ruff tifffile>=2022.2.9 -imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,einops]>=1.0.1 +monai[nibabel,einops,tifffile]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 2420dd1c..f3294b60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=1.0.1 + monai[nibabel,einops,tifffile]>=1.0.1 itk tqdm nibabel From 39efc1524c36d0c6759a754387b0cbb1414ec771 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 29 Mar 2023 09:55:58 +0200 Subject: [PATCH 224/577] Added LR scheduler in training - Added ReduceLROnPlateau with params in training - Updated training guide - Minor UI attribute refactor - black --- docs/res/code/plugin_model_training.rst | 1 - docs/res/guides/training_module_guide.rst | 2 + napari_cellseg3d/_tests/fixtures.py | 3 + .../_tests/test_plugin_inference.py | 2 +- .../code_models/model_framework.py | 2 +- .../code_models/model_instance_seg.py | 8 +-- napari_cellseg3d/code_models/model_workers.py | 11 ++++ napari_cellseg3d/code_plugins/plugin_base.py | 2 +- .../code_plugins/plugin_convert.py | 4 +- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- .../code_plugins/plugin_model_inference.py | 4 +- .../code_plugins/plugin_model_training.py | 62 ++++++++++++------- .../code_plugins/plugin_utilities.py | 2 +- napari_cellseg3d/config.py | 2 + napari_cellseg3d/interface.py | 43 +++++++------ 15 files changed, 93 insertions(+), 57 deletions(-) diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index 870dfd14..dc1271fc 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -18,6 +18,5 @@ Methods Attributes ********************* - .. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index fb8992d2..05ce69be 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -74,6 +74,8 @@ The training module is comprised of several tabs. * The **batch size** (larger means quicker training and possibly better performance but increased memory usage) * The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.) * The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.) +* The **schedular patience**, which is the amount of epoch at a plateau that is waited for until the learning rate is reduced +* The **scheduler factor**, which is the factor by which to reduce the learning rate once a plateau is reached * Whether to use deterministic training, and the seed to use. .. note:: diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index b40a77d3..bd6b0ac7 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -14,3 +14,6 @@ def print_and_log(self, text, printing=None): def warn(self, warning): warnings.warn(warning) + + def error(self, e): + raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..66c50fba 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -38,4 +38,4 @@ def test_inference(make_napari_viewer, qtbot): # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: # blocker.connect(widget.worker.errored) - # assert len(viewer.layers) == 2 + #### assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 2cc4265e..d541b486 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -78,7 +78,7 @@ def __init__( # ) self.model_choice = ui.DropdownMenu( - sorted(self.available_models.keys()), label="Model name" + sorted(self.available_models.keys()), text_label="Model name" ) self.weights_filewidget = ui.FilePathWidget( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e33d1d0f..0c87a2df 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -73,7 +73,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label="", parent=None), + ui.DoubleIncrementCounter(text_label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -426,13 +426,13 @@ def __init__(self): num_counters=2, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" self.sliders[0].setValue(50) - self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].label.setText("Seed probability threshold") self.sliders[1].tooltips = "Probability threshold for seeding" self.sliders[1].setValue(90) @@ -469,7 +469,7 @@ def __init__(self): num_counters=1, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index f5bb798c..bca24035 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -70,6 +70,7 @@ PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) +VERBOSE_SCHEDULER = True logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") @@ -1292,6 +1293,13 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + mode="min", + factor=self.config.scheduler_factor, + patience=self.config.scheduler_patience, + verbose=VERBOSE_SCHEDULER, + ) dice_metric = DiceMetric(include_background=True, reduction="mean") best_metric = -1 @@ -1384,6 +1392,9 @@ def train(self): epoch_loss_values.append(epoch_loss) self.log(f"Epoch: {epoch + 1}, Average loss: {epoch_loss:.4f}") + self.log("Updating scheduler...") + scheduler.step(epoch_loss) + checkpoint_output = [] if ( diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 0a613ee7..2cb3581b 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -99,7 +99,7 @@ def __init__( ) self.filetype_choice = ui.DropdownMenu( - [".tif", ".tiff"], label="File format" + [".tif", ".tiff"], text_label="File format" ) ######## qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 3346d2b8..a847ebf7 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -210,7 +210,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): lower=1, upper=100000, default=10, - label="Remove all smaller than (pxs):", + text_label="Remove all smaller than (pxs):", ) self.results_path = Path.home() / Path("cellseg3d/small_removed") @@ -472,7 +472,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): upper=100000.0, step=0.5, default=10.0, - label="Remove all smaller than (value):", + text_label="Remove all smaller than (value):", ) self.results_path = Path.home() / Path("cellseg3d/threshold") diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 9830d51e..1647e858 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -533,7 +533,7 @@ def set_slice( # container_widget.extend(sliders) ui.add_widgets( container_widget.layout, - [ui.combine_blocks(s, s.text_label) for s in sliders], + [ui.combine_blocks(s, s.label) for s in sliders], ) # vw.window.add_dock_widget([spinbox, container_widget], area="right") wdgts = vw.window.add_dock_widget( diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ab61b590..44e70a76 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -105,7 +105,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ###################### # TODO : better way to handle SegResNet size reqs ? self.model_input_size = ui.IntIncrementCounter( - lower=1, upper=1024, default=128, label="\nModel input size" + lower=1, upper=1024, default=128, text_label="\nModel input size" ) self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size @@ -157,7 +157,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): # ) self.window_size_choice = ui.DropdownMenu( - sizes_window, label="Window size" + sizes_window, text_label="Window size" ) self.window_size_choice.setCurrentIndex(3) # set to 64 by default diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 4f2f7cdf..132c9531 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -46,6 +46,8 @@ class Trainer(ModelFramework, metaclass=ui.QWidgetSingleton): Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" + default_config = config.TrainingWorkerConfig() + def __init__( self, viewer: "napari.viewer.Viewer", @@ -168,14 +170,13 @@ def __init__( ################################ # interface - default = config.TrainingWorkerConfig() self.zip_choice = ui.CheckBox("Compress results") self.validation_percent_choice = ui.Slider( lower=10, upper=90, - default=default.validation_percent * 100, + default=self.default_config.validation_percent * 100, step=5, parent=self, ) @@ -183,12 +184,12 @@ def __init__( self.epoch_choice = ui.IntIncrementCounter( lower=2, upper=200, - default=default.max_epochs, - label="Number of epochs : ", + default=self.default_config.max_epochs, + text_label="Number of epochs : ", ) self.loss_choice = ui.DropdownMenu( - sorted(self.loss_dict.keys()), label="Loss function" + sorted(self.loss_dict.keys()), text_label="Loss function" ) self.lbl_loss_choice = self.loss_choice.label self.loss_choice.setCurrentIndex(0) @@ -196,7 +197,7 @@ def __init__( self.sample_choice_slider = ui.Slider( lower=2, upper=50, - default=default.num_samples, + default=self.default_config.num_samples, text_label="Number of patches per image : ", ) @@ -205,13 +206,13 @@ def __init__( self.batch_choice = ui.Slider( lower=1, upper=10, - default=default.batch_size, + default=self.default_config.batch_size, text_label="Batch size : ", ) self.val_interval_choice = ui.IntIncrementCounter( - default=default.validation_interval, - label="Validation interval : ", + default=self.default_config.validation_interval, + text_label="Validation interval : ", ) self.epoch_choice.valueChanged.connect(self._update_validation_choice) @@ -228,12 +229,24 @@ def __init__( ] self.learning_rate_choice = ui.DropdownMenu( - learning_rate_vals, label="Learning rate" + learning_rate_vals, text_label="Learning rate" ) self.lbl_learning_rate_choice = self.learning_rate_choice.label self.learning_rate_choice.setCurrentIndex(1) + self.scheduler_patience_choice = ui.IntIncrementCounter( + 1, + 99, + default=self.default_config.scheduler_patience, + text_label="Scheduler patience", + ) + self.scheduler_factor_choice = ui.Slider( + divide_factor=100, + default=self.default_config.scheduler_factor * 100, + text_label="Scheduler factor :", + ) + self.augment_choice = ui.CheckBox("Augment data") self.close_buttons = [ @@ -268,7 +281,8 @@ def __init__( "Deterministic training", func=self._toggle_deterministic_param ) self.box_seed = ui.IntIncrementCounter( - upper=10000000, default=default.deterministic_config.seed + upper=10000000, + default=self.default_config.deterministic_config.seed, ) self.lbl_seed = ui.make_label("Seed", self) self.container_seed = ui.combine_blocks( @@ -309,6 +323,12 @@ def set_tooltips(): self.learning_rate_choice.setToolTip( "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights" ) + self.scheduler_factor_choice.setToolTip( + "The factor by which to reduce the learning rate once the loss reaches a plateau" + ) + self.scheduler_patience_choice.setToolTip( + "The amount of epochs to wait for before reducing the learning rate" + ) self.augment_choice.setToolTip( "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images" " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image" @@ -632,26 +652,20 @@ def _build(self): "Training parameters", r=1, b=5, t=11 ) - spacing = 20 - ui.add_widgets( train_param_group_l, [ self.batch_choice.container, # batch size - ui.combine_blocks( - self.learning_rate_choice, - self.lbl_learning_rate_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # learning rate + self.lbl_learning_rate_choice, + self.learning_rate_choice, self.epoch_choice.label, # epochs self.epoch_choice, self.val_interval_choice.label, self.val_interval_choice, # validation interval + self.scheduler_patience_choice.label, + self.scheduler_patience_choice, + self.scheduler_factor_choice.label, + self.scheduler_factor_choice.container, ], None, ) @@ -833,6 +847,8 @@ def start(self): max_epochs=self.epoch_choice.value(), loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.value(), validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index fdcad6d3..45c0c119 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -43,7 +43,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # self.small = RemoveSmallUtils(self._viewer) self.utils_choice = ui.DropdownMenu( - UTILITIES_WIDGETS.keys(), label="Utilities" + UTILITIES_WIDGETS.keys(), text_label="Utilities" ) self._build() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3d1d6d9e..afc16bd3 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -230,6 +230,8 @@ class TrainingWorkerConfig: max_epochs: int = 5 loss_function: callable = None learning_rate: np.float64 = 1e-3 + scheduler_patience: int = 10 + scheduler_factor: float = 0.5 validation_interval: int = 2 batch_size: int = 1 results_path_folder: str = str(Path.home() / Path("cellseg3d/training")) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 57b3b0bd..9a100dc2 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -408,21 +408,21 @@ def __init__( self, entries: Optional[list] = None, parent: Optional[QWidget] = None, - label: Optional[str] = None, + text_label: Optional[str] = None, fixed: Optional[bool] = True, ): """Args: entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None - label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + text_label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. """ super().__init__(parent) self.label = None if entries is not None: self.addItems(entries) - if label is not None: - self.label = QLabel(label) + if text_label is not None: + self.label = QLabel(text_label) if fixed: self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) @@ -473,9 +473,10 @@ def __init__( self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.text_label = None + self.label = None self.container = ContainerWidget( - # parent=self.parent + # parent=self.parent, + b=0, ) self._divide_factor = divide_factor @@ -498,7 +499,7 @@ def __init__( ) if text_label is not None: - self.text_label = make_label(text_label, parent=self) + self.label = make_label(text_label, parent=self) if default < lower: self._warn_outside_bounds(default) @@ -517,14 +518,14 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.text_label.setVisible(visible) + self.label.setVisible(visible) def _build_container(self): - if self.text_label is not None: + if self.label is not None: add_widgets( self.container.layout, [ - self.text_label, + self.label, combine_blocks(self._value_label, self, b=0), ], ) @@ -568,8 +569,8 @@ def tooltips(self, tooltip: str): self.setToolTip(tooltip) self._value_label.setToolTip(tooltip) - if self.text_label is not None: - self.text_label.setToolTip(tooltip) + if self.label is not None: + self.label.setToolTip(tooltip) @property def slider_value(self): @@ -739,7 +740,9 @@ def __init__( self.image = None self.layer_type = layer_type - self.layer_list = DropdownMenu(parent=self, label=name, fixed=False) + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? self._viewer.layers.events.inserted.connect(partial(self._add_layer)) @@ -1044,7 +1047,7 @@ def __init__( step: Optional[float] = 1.0, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[float]): minimum value, defaults to 0 @@ -1053,7 +1056,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) @@ -1061,8 +1064,8 @@ def __init__( self.layout = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) self.valueChanged.connect(self._update_step) def _update_step(self): @@ -1122,7 +1125,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[int]): minimum value, defaults to 0 @@ -1138,8 +1141,8 @@ def __init__( self.label = None self.container = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) @property def tooltips(self): From 39b4391711c8cf74052cc16ee4be2e1c268dbc2e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 31 Mar 2023 15:45:00 +0200 Subject: [PATCH 225/577] Update assess_instance.ipynb --- notebooks/assess_instance.ipynb | 162 ++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 61 deletions(-) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 169775f5..3dae22a9 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -49,20 +49,10 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -72,15 +62,16 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "# prediction_resized = resize(prediction, zoom)\n", - "prediction_resized = prediction # for trailmap\n", + "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -103,7 +94,7 @@ { "data": { "text/plain": [ - "0.7538125057831502" + "0.8592223181276479" ] }, "execution_count": 4, @@ -189,7 +180,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -216,24 +207,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,590 - Mapping labels...\n" + "2023-03-31 15:37:19,775 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", - "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", - "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" + "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", + "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -246,15 +237,15 @@ { "data": { "text/plain": [ - "(63,\n", - " 45,\n", - " 16,\n", - " 16,\n", - " 0.819027731148306,\n", - " 0.8401649108992161,\n", - " 0.83609908334452,\n", - " 0.8066092803671974,\n", - " 0.98)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 7, @@ -280,24 +271,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,732 - Mapping labels...\n" + "2023-03-31 15:37:19,919 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", - "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", - "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", + "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -310,15 +301,15 @@ { "data": { "text/plain": [ - "(77,\n", - " 34,\n", + "(68,\n", + " 43,\n", " 13,\n", - " 9,\n", - " 0.728461197681457,\n", - " 0.8885669859686413,\n", - " 0.8950588507577087,\n", - " 0.7472814623489069,\n", - " 0.878614359974009)" + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 8, @@ -344,6 +335,40 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:37:21,076 - build program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", + "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", + "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", + "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", + "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", + "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", + "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", + "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", + "2023-03-31 15:37:32,685 - build program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", + "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" + ] + }, { "data": { "text/plain": [ @@ -503,24 +528,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,241 - Mapping labels...\n" + "2023-03-31 15:37:36,854 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", - "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", - "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" + "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", + "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", + "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -533,15 +558,15 @@ { "data": { "text/plain": [ - "(102,\n", + "(109,\n", + " 2,\n", + " 13,\n", " 8,\n", - " 14,\n", - " 16,\n", - " 0.708505702558253,\n", - " 0.8832633585884945,\n", - " 0.9759871495093808,\n", - " 0.6670483272595948,\n", - " 0.8653680990771155)" + " 0.8285521200005869,\n", + " 0.8809251900364068,\n", + " 0.9838709677419355,\n", + " 0.782258064516129,\n", + " 1.0)" ] }, "execution_count": 13, @@ -565,10 +590,25 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" + ] + } + ], "source": [ "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -587,7 +627,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, From 1864976e4e92141a028321bdb3e6d935dc94982b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 11:09:30 +0200 Subject: [PATCH 226/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..df67a187 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb *.prof #include test data From c5136029c88b83c04d06d9f959e34951283e41d2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 14:27:21 +0200 Subject: [PATCH 227/577] Started adding WNet --- napari_cellseg3d/code_models/model_workers.py | 4 +- .../code_models/models/model_SwinUNetR.py | 29 +- .../code_models/models/model_TRAILMAP_MS.py | 15 +- .../code_models/models/model_WNet.py | 27 ++ .../pretrained/pretrained_model_urls.json | 1 + .../code_models/models/wnet/__init__.py | 0 .../code_models/models/wnet/crf.py | 112 ++++++ .../code_models/models/wnet/model.py | 189 ++++++++++ .../code_models/models/wnet/soft_Ncuts.py | 352 ++++++++++++++++++ napari_cellseg3d/config.py | 22 ++ 10 files changed, 739 insertions(+), 12 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/model_WNet.py create mode 100644 napari_cellseg3d/code_models/models/wnet/__init__.py create mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py create mode 100644 napari_cellseg3d/code_models/models/wnet/model.py create mode 100644 napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index bca24035..7a45c47e 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -761,7 +761,9 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=[128, 128, 128], + input_img_size=dims, + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) # try: model = model.to(self.config.device) diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index fe4d380c..f38409b8 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,7 @@ from monai.networks.nets import SwinUNETR +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class SwinUNETR_(SwinUNETR): @@ -6,14 +9,24 @@ class SwinUNETR_(SwinUNETR): weights_file = "Swin64_best_metric.pth" def __init__(self, input_img_size, use_checkpoint=True, **kwargs): - super().__init__( - input_img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - **kwargs - ) + try: + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs, + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index e3ca00a6..1123173a 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,4 +1,7 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class TRAILMAP_MS_(UNet3D): @@ -8,9 +11,15 @@ class TRAILMAP_MS_(UNet3D): # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) def __init__(self, in_channels=1, out_channels=1, **kwargs): - super().__init__( - in_channels=in_channels, out_channels=out_channels, **kwargs - ) + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py new file mode 100644 index 00000000..63a91b10 --- /dev/null +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -0,0 +1,27 @@ +from napari_cellseg3d.code_models.models.wnet.model import WNet + + +class WNet_(WNet): + use_default_training = False + weights_file = "wnet.pth" + + def __init__( + self, + in_channels=1, + out_channels=1, + num_classes=2, + device="cpu", + **kwargs + ): + super().__init__( + device=device, + in_channels=in_channels, + out_channels=out_channels, + num_classes=num_classes, + ) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + # dec = self.forward_decoder(enc) + return enc diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index cd0782fb..cde5e332 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -3,5 +3,6 @@ "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz", "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", + "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" } diff --git a/napari_cellseg3d/code_models/models/wnet/__init__.py b/napari_cellseg3d/code_models/models/wnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py new file mode 100644 index 00000000..ca11fba2 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -0,0 +1,112 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +import numpy as np +import pydensecrf.densecrf as dcrf +from pydensecrf.utils import ( + unary_from_softmax, + create_pairwise_gaussian, + create_pairwise_bilateral, +) + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py new file mode 100644 index 00000000..585ea0dd --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -0,0 +1,189 @@ +""" +Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. +The model performs unsupervised segmentation of 3D images. +""" + +import torch +import torch.nn as nn + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", +] + + +class WNet(nn.Module): + """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. + The model performs unsupervised segmentation of 3D images. + It first encodes the input image into a latent space using the U-Net UEncoder, then decodes it back to the original image using the U-Net UDecoder. + """ + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super(WNet, self).__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + self.decoder = UNet(device, num_classes, out_channels, encoder=False) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + dec = self.forward_decoder(enc) + return enc, dec + + def forward_encoder(self, x): + """Forward pass of the encoder part of the W-Net model.""" + enc = self.encoder(x) + return enc + + def forward_decoder(self, enc): + """Forward pass of the decoder part of the W-Net model.""" + dec = self.decoder(enc) + return dec + + +class UNet(nn.Module): + """Half of the W-Net model, based on the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels, encoder=True): + super(UNet, self).__init__() + self.device = device + self.in_b = InBlock(device, in_channels, 64) + self.conv1 = Block(device, 64, 128) + self.conv2 = Block(device, 128, 256) + self.conv3 = Block(device, 256, 512) + self.bot = Block(device, 512, 1024) + self.deconv1 = Block(device, 1024, 512) + self.deconv2 = Block(device, 512, 256) + self.deconv3 = Block(device, 256, 128) + self.out_b = OutBlock(device, 128, out_channels) + + self.sm = nn.Softmax(dim=1).to(device) + self.encoder = encoder + + def forward(self, x): + """Forward pass of the U-Net model.""" + in_b = self.in_b(x.to(self.device)) + c1 = self.conv1(nn.MaxPool3d(2)(in_b)) + c2 = self.conv2(nn.MaxPool3d(2)(c1)) + c3 = self.conv3(nn.MaxPool3d(2)(c2)) + x = self.bot(nn.MaxPool3d(2)(c3)) + x = self.deconv1( + torch.cat( + [ + c3, + nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv2( + torch.cat( + [ + c2, + nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv3( + torch.cat( + [ + c1, + nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.out_b( + torch.cat( + [ + in_b, + nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + if self.encoder: + x = self.sm(x) + return x + + +class InBlock(nn.Module): + """Input block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(InBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the input block.""" + return self.module(x.to(self.device)) + + +class Block(nn.Module): + """Basic block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(Block, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), + nn.Conv3d(in_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.Conv3d(out_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the basic block.""" + return self.module(x.to(self.device)) + + +class OutBlock(nn.Module): + """Output block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(OutBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, out_channels, 1, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the output block.""" + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py new file mode 100644 index 00000000..6a625355 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -0,0 +1,352 @@ +""" +Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. +The implementation was adapted and approximated to reduce computational and memory cost. +This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from scipy.stats import norm + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", + "Jianbo Shi", + "Jitendra Malik", + "Frank Odom", +] + + +class SoftNCutsLoss(nn.Module): + """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. + + Args: + data_shape (H, W, D): shape of the images as a tuple. + o_i (scalar): scale of the gaussian kernel of pixels brightness. + o_x (scalar): scale of the gaussian kernel of pixels spacial distance. + radius (scalar): radius of pixels for which we compute the weights + """ + + def __init__(self, data_shape, device, o_i, o_x, radius=None): + super(SoftNCutsLoss, self).__init__() + self.o_i = o_i + self.o_x = o_x + self.radius = radius + self.H = data_shape[0] + self.W = data_shape[1] + self.D = data_shape[2] + self.device = device + + if self.radius is None: + self.radius = min( + max(5, math.ceil(min(self.H, self.W, self.D) / 20)), + self.H, + self.W, + self.D, + ) + + # self.distances, self.indexes = self.get_distances() + + """ + + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration + distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) + distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) + distances_D = torch.tensor(range(self.D)).expand(self.D, self.D) # (D, D) + + # Compute in cuda if possible + if torch.cuda.is_available(): + distances_H = distances_H.cuda() + distances_W = distances_W.cuda() + distances_D = distances_D.cuda() + + distances_H = torch.abs(torch.subtract(distances_H, distances_H.T)) # (H, H) + distances_W = torch.abs(torch.subtract(distances_W, distances_W.T)) # (W, W) + distances_D = torch.abs(torch.subtract(distances_D, distances_D.T)) # (D, D) + + distances_H = distances_H.view(self.H, 1, 1, self.H, 1, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (H, 1, 1, H, 1, 1) -> (H, W, D, H, W, D) + distances_W = distances_W.view(1, self.W, 1, 1, self.W, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, W, 1, 1, W, 1) -> (H, W, D, H, W, D) + distances_D = distances_D.view(1, 1, self.D, 1, 1, self.D).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, 1, D, 1, 1, D) -> (H, W, D, H, W, D) + + mask_H = torch.le(distances_H, self.radius).bool() # (H, W, D, H, W, D) + mask_W = torch.le(distances_W, self.radius).bool() # (H, W, D, H, W, D) + mask_D = torch.le(distances_D, self.radius).bool() # (H, W, D, H, W, D) + + distances_H = (distances_H * mask_H) # (H, W, D, H, W, D) + distances_W = (distances_W * mask_W) # (H, W, D, H, W, D) + distances_D = (distances_D * mask_D) # (H, W, D, H, W, D) + + mask_H =mask_H.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_W =mask_W.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_D =mask_D.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + + distances_H = distances_H.pow(2) # (H, W, D, H, W, D) + distances_W = distances_W.pow(2) # (H, W, D, H, W, D) + distances_D = distances_D.pow(2) # (H, W, D, H, W, D) + + squared_distances = torch.add( + torch.add(distances_H, distances_W), + distances_D, + ) # (H, W, D, H, W, D) + + squared_distances = squared_distances.flatten(0, 2).flatten( + 1, 3 + ) # (H*W*D, H*W*D) + + # Mask to only keep the weights for the pixels in the radius + self.mask = torch.le(squared_distances, self.radius**2).bool() # (H*W*D, H*W*D) + + # Add all masks to get the final mask + self.mask = self.mask.logical_and(mask_H).logical_and(mask_W).logical_and(mask_D) # (H*W*D, H*W*D) + + W_X = torch.exp( + torch.neg(torch.div(squared_distances, self.o_x)) + ) # (H*W*D, H*W*D) + + self.W_X = torch.mul(W_X, self.mask) # (H*W*D, H*W*D) + """ + + def forward(self, labels, inputs): + """Forward pass of the Soft N-Cuts loss. + + Args: + labels (torch.Tensor): Tensor of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + The Soft N-Cuts loss of shape (N,). + """ + inputs.shape[0] + inputs.shape[1] + K = labels.shape[1] + + labels.to(self.device) + inputs.to(self.device) + + loss = 0 + + kernel = self.gaussian_kernel(self.radius, self.o_x).to(self.device) + + for k in range(K): + # Compute the average pixel value for this class, and the difference from each pixel + class_probs = labels[:, k].unsqueeze(1) + class_mean = torch.mean( + inputs * class_probs, dim=(2, 3, 4), keepdim=True + ) / torch.add( + torch.mean(class_probs, dim=(2, 3, 4), keepdim=True), 1e-5 + ) + diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1) + + # Weight the loss by the difference from the class average. + weights = torch.exp(diff.pow(2).mul(-1 / self.o_i**2)) + + numerator = torch.sum( + class_probs + * F.conv3d(class_probs * weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + denominator = torch.sum( + class_probs * F.conv3d(weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + loss += nn.L1Loss()( + numerator / torch.add(denominator, 1e-6), + torch.zeros_like(numerator), + ) + + return K - loss + + """ + for k in range(K): + Ak = labels[:, k, :, :, :] # (N, H, W, D) + flatted_Ak = Ak.view(N, -1) # (N, H*W*D) + + # Compute the numerator of the Soft N-Cuts loss for k + flatted_Ak_unsqueeze = flatted_Ak.unsqueeze(1) # (N, 1, H*W*D) + transposed_Ak = torch.transpose(flatted_Ak_unsqueeze, 1, 2) # (N, H*W*D, 1) + probs = torch.bmm(transposed_Ak, flatted_Ak_unsqueeze) # (N, H*W*D, H*W*D) + probs_unsqueeze_expanded = probs.unsqueeze(1) # (N, 1, H*W*D, H*W*D) + numerator_elements = torch.mul( + probs_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + numerator = torch.sum(numerator_elements, dim=(2, 3)) # (N, C) + + # Compute the denominator of the Soft N-Cuts loss for k + expanded_flatted_Ak = flatted_Ak.expand( + -1, self.H * self.W * self.D + ) # (N, H*W*D, H*W*D) + e_f_Ak_unsqueeze_expanded = expanded_flatted_Ak.unsqueeze( + 1 + ) # (N, 1, H*W*D, H*W*D) + denominator_elements = torch.mul( + e_f_Ak_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + denominator = torch.sum(denominator_elements, dim=(2, 3)) # (N, C) + + # Compute the Soft N-Cuts loss for k + division = torch.div(numerator, torch.add(denominator, 1e-8)) # (N, C) + loss = torch.sum(division, dim=1) # (N,) + losses.append(loss) + + loss = torch.sum(torch.stack(losses, dim=0), dim=0) # (N,) + + return torch.add(torch.neg(loss), K) + """ + + def gaussian_kernel(self, radius, sigma): + """Computes the Gaussian kernel. + + Args: + radius (int): The radius of the kernel. + sigma (float): The standard deviation of the Gaussian distribution. + + Returns: + The Gaussian kernel of shape (1, 1, 2*radius+1, 2*radius+1, 2*radius+1). + """ + x_2 = np.linspace(-radius, radius, 2 * radius + 1) ** 2 + dist = ( + np.sqrt( + x_2.reshape(-1, 1, 1) + + x_2.reshape(1, -1, 1) + + x_2.reshape(1, 1, -1) + ) + / sigma + ) + kernel = norm.pdf(dist) / norm.pdf(0) + kernel = torch.from_numpy(kernel.astype(np.float32)) + kernel = kernel.view( + (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]) + ) + + return kernel + + def get_distances(self): + """Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration. + + Returns: + distances (dict): for each pixel index, we get the distances to the pixels in a radius around it. + """ + + distances = dict() + indexes = np.array( + [ + (i, j, k) + for i in range(self.H) + for j in range(self.W) + for k in range(self.D) + ] + ) + + for i in indexes: + iTuple = (i[0], i[1], i[2]) + distances[iTuple] = dict() + + sliceD = indexes[ + i[0] * self.H + + i[1] * self.W + + max(0, i[2] - self.radius) : i[0] * self.H + + i[1] * self.W + + min(self.D, i[2] + self.radius) + ] + sliceW = indexes[ + i[0] * self.H + + max(0, i[1] - self.radius) * self.W + + i[2] : i[0] * self.H + + min(self.W, i[1] + self.radius) * self.W + + i[2] : self.D + ] + sliceH = indexes[ + max(0, i[0] - self.radius) * self.H + + i[1] * self.W + + i[2] : min(self.H, i[0] + self.radius) * self.H + + i[1] * self.W + + i[2] : self.D * self.W + ] + + for j in np.concatenate((sliceD, sliceW, sliceH)): + jTuple = (j[0], j[1], j[2]) + distance = np.linalg.norm(i - j) + if distance > self.radius: + continue + distance = math.exp(-(distance**2) / (self.o_x**2)) + + if jTuple not in distances: + distances[iTuple][jTuple] = distance + + return distances, indexes + + def get_weights(self, inputs): + """Computes the weights matrix for the Soft N-Cuts loss. + + Args: + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + list: List of the weights dict for each image in the batch. + """ + + """ + weights = [] + for n in range(inputs.shape[0]): + weightsChannel = [] + for c in range(inputs.shape[1]): + weightsImage = dict() + for i in self.indexes: + iTuple = (i[0], i[1], i[2]) + weightsImage[iTuple] = dict() + for j in self.indexes: + jTuple = (j[0], j[1], j[2]) + if iTuple in self.distances and jTuple in self.distances[i]: + brightness = ( + inputs[n][c][i[0]][i[1]][i[2]] + - inputs[n][c][j[0]][j[1]][j[2]] + ) ** 2 + brightness = math.exp(-brightness / self.o_i**2) + weightsImage[iTuple][jTuple] = ( + self.distances[iTuple][jTuple] * brightness + ) + + weightsChannel.append(weightsImage) + + weights.append(weightsChannel) + + return weights + + """ + + # Compute the brightness distance of the pixels + flatted_inputs = inputs.view( + inputs.shape[0], inputs.shape[1], -1 + ) # (N, C, H*W*D) + I_diff = torch.subtract( + flatted_inputs.unsqueeze(3), flatted_inputs.unsqueeze(2) + ) # (N, C, H*W*D, H*W*D) + masked_I_diff = torch.mul(I_diff, self.mask) # (N, C, H*W*D, H*W*D) + squared_I_diff = torch.pow(masked_I_diff, 2) # (N, C, H*W*D, H*W*D) + + W_I = torch.exp( + torch.neg(torch.div(squared_I_diff, self.o_i)) + ) # (N, C, H*W*D, H*W*D) + W_I = torch.mul(W_I, self.mask) # (N, C, H*W*D, H*W*D) + + # Get the spatial distance of the pixels + unsqueezed_W_X = self.W_X.view( + 1, 1, self.W_X.shape[0], self.W_X.shape[1] + ) # (1, 1, H*W*D, H*W*D) + + W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) + return W diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index afc16bd3..4eaddb93 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -14,6 +14,7 @@ from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ +from napari_cellseg3d.code_models.models.model_WNet import WNet_ from napari_cellseg3d.utils import LOGGER @@ -28,6 +29,7 @@ # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, + "WNet": WNet_, # "test" : DO NOT USE, reserved for testing } @@ -71,10 +73,12 @@ class ModelInfo: Args: name (str): name of the model model_input_size (Optional[List[int]]): input size of the model + num_classes (int): number of classes for the model """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None + num_classes: int = 2 def get_model(self): try: @@ -240,3 +244,21 @@ class TrainingWorkerConfig: sample_size: List[int] = None do_augmentation: bool = True deterministic_config: DeterministicConfig = DeterministicConfig() + + +################ +# CRF config for WNet +################ + + +@dataclass +class WNetCRFConfig: + "Class to store parameters of WNet CRF post processing" + + # CRF + sa = 10 # 50 + sb = 10 + sg = 1 + w1 = 10 # 50 + w2 = 10 + n_iter = 5 From e8c13aa43d4cdc91f7489abc12c2babf7f8978e3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 20 Apr 2023 11:12:59 +0200 Subject: [PATCH 228/577] Specify no grad in inference --- napari_cellseg3d/code_models/model_workers.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7a45c47e..6bc088e6 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -490,16 +490,17 @@ def model_output_wrapper(inputs): result = model(inputs) return post_process_transforms(result) - outputs = sliding_window_inference( - inputs, - roi_size=window_size, - sw_batch_size=1, # TODO add param - predictor=model_output_wrapper, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) + with torch.no_grad(): + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) except Exception as e: logger.error(e, exc_info=True) logger.debug("failed to run sliding window inference") @@ -1412,16 +1413,17 @@ def train(self): ) self.log("Performing validation...") try: - val_outputs = sliding_window_inference( - val_inputs, - roi_size=size, - sw_batch_size=self.config.batch_size, - predictor=model, - overlap=0.25, - sw_device=self.config.device, - device=self.config.device, - progress=True, - ) + with torch.no_grad(): + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) except Exception as e: self.raise_error(e, "Error during validation") logger.debug( From 73f6c8e8519ee6bccaba6aa8a3d98a79a8a734f1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 22 Apr 2023 14:12:32 +0200 Subject: [PATCH 229/577] First functional WNet inference, no CRF --- napari_cellseg3d/code_models/model_workers.py | 46 +++++++++++---- .../code_models/models/model_WNet.py | 3 +- .../code_plugins/plugin_model_inference.py | 57 +++++++++++-------- 3 files changed, 71 insertions(+), 35 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 6bc088e6..33f0ee12 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -199,7 +199,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None - stats: ImageStats = None + stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -541,7 +541,10 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - semantic_labels = np.swapaxes(semantic_labels, 0, 2) + total_dims = len(semantic_labels.shape) - 3 + semantic_labels = np.swapaxes( + semantic_labels, 0 + total_dims, 2 + total_dims + ) return InferenceResult( image_id=i + 1, @@ -582,8 +585,10 @@ def save_image( ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" + filetype = self.config.filetype else: original_filename = "_" + filetype = "" time = utils.get_date_time() @@ -594,7 +599,7 @@ def save_image( + original_filename + self.config.model_info.name + f"_{time}_" - + self.config.filetype + + filetype ) try: imwrite(file_path, image) @@ -619,22 +624,35 @@ def aniso_transform(self, image): else: return image - def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + def instance_seg( + self, to_instance, image_id=0, original_filename="layer", channel=None + ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method instance_labels = method.run_method(image=to_instance) + if channel is not None: + channel_id = f"_{channel}" + else: + channel_id = "" + + if self.config.filetype == "": + filetype = "" + else: + filetype = "_" + self.config.filetype + instance_filepath = ( self.config.results_path + "/" + f"Instance_seg_labels_{image_id}_" + original_filename + + channel_id + "_" + self.config.model_info.name - + f"_{utils.get_date_time()}_" - + self.config.filetype + + f"_{utils.get_date_time()}" + + filetype ) imwrite(instance_filepath, instance_labels) @@ -699,13 +717,21 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, stats = self.get_instance_result(out, from_layer=True) + instance_labels_results = [] + stats_results = [] + + for channel in out: + instance_labels, stats = self.get_instance_result( + channel, from_layer=True + ) + instance_labels_results.append(instance_labels) + stats_results.append(stats) return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels, + instance_labels=instance_labels_results, from_layer=True, - stats=stats, + stats=stats_results, ) # @thread_worker(connect={"errored": self.raise_error}) @@ -762,7 +788,7 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=dims, + input_img_size=[dims, dims, dims], device=self.config.device, num_classes=self.config.model_info.num_classes, ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 63a91b10..dffa3b44 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,8 @@ def __init__( ) def forward(self, x): - """Forward pass of the W-Net model.""" + """Forward ENCODER pass of the W-Net model. + Done this way to allow inference on the encoder only when called by sliding_window_inference.""" enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 44e70a76..522f91bb 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -733,37 +733,46 @@ def on_yield(self, result: InferenceResult): ) if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + for i, labels in enumerate(result.instance_labels): + # labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + from napari_cellseg3d.utils import LOGGER as log - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + log.debug(f"len stats : {len(result.stats)}") - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + for i, stats in enumerate(result.stats): + # stats = result.stats - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) + + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) + + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) except Exception as e: self.on_error(e) From 89f3701520665211e9e9aabda1962760b8539c27 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:48:12 +0200 Subject: [PATCH 230/577] Create test_models.py --- napari_cellseg3d/_tests/test_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 napari_cellseg3d/_tests/test_models.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py new file mode 100644 index 00000000..e2ba32e0 --- /dev/null +++ b/napari_cellseg3d/_tests/test_models.py @@ -0,0 +1,13 @@ +from napari_cellseg3d.config import MODEL_LIST + + +def test_model_list(): + for model_name in MODEL_LIST.keys(): + dims = 128 + test = MODEL_LIST[model_name]( + input_img_size=[dims, dims, dims], + in_channels=1, + out_channels=1, + dropout_prob=0.3, + ) + assert isinstance(test, MODEL_LIST[model_name]) From 1824debf0ec89d665e376a0e2d718df927abb51b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:42:56 +0200 Subject: [PATCH 231/577] Run full suite of pre-commit hooks --- docs/res/guides/custom_model_template.rst | 2 -- napari_cellseg3d/code_models/model_instance_seg.py | 2 ++ napari_cellseg3d/code_models/model_workers.py | 5 ++--- napari_cellseg3d/code_models/models/model_SwinUNetR.py | 1 + napari_cellseg3d/code_models/models/model_WNet.py | 3 ++- napari_cellseg3d/code_models/models/wnet/crf.py | 8 +++----- napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py | 8 ++++---- napari_cellseg3d/config.py | 1 - napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - 9 files changed, 14 insertions(+), 17 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 9bad49b0..218795b1 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -9,5 +9,3 @@ To add a custom model, you will need a **.py** file with the following structure **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute :: - - diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 0c87a2df..cc7fac90 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -12,6 +12,8 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 33f0ee12..dce2f452 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,8 +1,8 @@ import platform +import typing as t from dataclasses import dataclass from math import ceil from pathlib import Path -import typing as t import numpy as np import torch @@ -40,10 +40,9 @@ ) from monai.utils import set_determinism +# from napari.qt.threading import thread_worker # threads from napari.qt.threading import GeneratorWorker - -# from napari.qt.threading import thread_worker from napari.qt.threading import WorkerBaseSignals # Qt diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index f38409b8..05819e22 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,5 @@ from monai.networks.nets import SwinUNETR + from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index dffa3b44..750b8bdb 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -22,7 +22,8 @@ def __init__( def forward(self, x): """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference.""" + Done this way to allow inference on the encoder only when called by sliding_window_inference. + """ enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index ca11fba2..2ac0875d 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,11 +12,9 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - unary_from_softmax, - create_pairwise_gaussian, - create_pairwise_bilateral, -) +from pydensecrf.utils import create_pairwise_bilateral +from pydensecrf.utils import create_pairwise_gaussian +from pydensecrf.utils import unary_from_softmax __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 6a625355..4e84579f 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -1,15 +1,15 @@ """ Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. -The implementation was adapted and approximated to reduce computational and memory cost. +The implementation was adapted and approximated to reduce computational and memory cost. This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. """ import math + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -import numpy as np from scipy.stats import norm __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" @@ -56,7 +56,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): # self.distances, self.indexes = self.get_distances() """ - + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 4eaddb93..43f961f4 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -15,7 +15,6 @@ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ from napari_cellseg3d.code_models.models.model_WNet import WNet_ - from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..48249a94 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage From e0eab0f1a6be171d992eb67c82602b9110670165 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:27:18 +0200 Subject: [PATCH 232/577] Patch for tests action + style --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/code_models/model_instance_seg.py | 6 ++++-- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 + napari_cellseg3d/utils.py | 1 + 5 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ea0a1e46..88a67ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -16,6 +16,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet workflow_dispatch: jobs: diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cc7fac90..2f10aa1f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -12,14 +12,16 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 750b8bdb..4a9ff70d 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -11,7 +11,7 @@ def __init__( out_channels=1, num_classes=2, device="cpu", - **kwargs + **kwargs, ): super().__init__( device=device, diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 48249a94..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 5683c541..9fbe6d7a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,6 +2,7 @@ import warnings from datetime import datetime from pathlib import Path + import numpy as np from monai.transforms import Zoom from skimage import io From 5ffbc6289431dab0ae366c9a62cac894f200f270 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 16:03:29 +0200 Subject: [PATCH 233/577] Add softNCuts basic test --- napari_cellseg3d/_tests/test_models.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index e2ba32e0..9280b230 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,3 +1,6 @@ +import torch + +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST @@ -11,3 +14,20 @@ def test_model_list(): dropout_prob=0.3, ) assert isinstance(test, MODEL_LIST[model_name]) + + +def test_soft_ncuts_loss(): + dims = 8 + labels = torch.rand([1, 1, dims, dims, dims]) + + loss = SoftNCutsLoss( + data_shape=[dims, dims, dims], + device="cpu", + o_i=4, + o_x=4, + radius=2, + ) + + res = loss.forward(labels, labels) + assert isinstance(res, torch.Tensor) + # assert res > 0 From 5773065f549be148925c6df4920c246eb98a0306 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 09:41:15 +0200 Subject: [PATCH 234/577] Added crf Co-Authored-By: Nevexios <72894299+nevexios@users.noreply.github.com> --- napari_cellseg3d/code_models/crf.py | 122 ++++++++++++++++++++++++++++ pyproject.toml | 3 + 2 files changed, 125 insertions(+) create mode 100644 napari_cellseg3d/code_models/crf.py diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py new file mode 100644 index 00000000..13f489c7 --- /dev/null +++ b/napari_cellseg3d/code_models/crf.py @@ -0,0 +1,122 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +from warnings import warn + +import numpy as np + +try: + import pydensecrf.densecrf as dcrf + from pydensecrf.utils import create_pairwise_bilateral + from pydensecrf.utils import create_pairwise_gaussian + from pydensecrf.utils import unary_from_softmax + + CRF_INSTALLED = True +except ImportError: + warn( + "pydensecrf not installed, CRF post-processing will not be available. " + "Please install by running pip install cellseg3d[crf]" + ) + CRF_INSTALLED = False + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/pyproject.toml b/pyproject.toml index d2a2adbb..d9a46ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ profile = "black" line_length = 79 [project.optional-dependencies] +crf = [ +"git+https://github.com/lucasb-eyer/pydensecrf.git", +] dev = [ "isort", "black", From cb52b930900dc5802290826ce344c60b69d45c96 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 10:08:46 +0200 Subject: [PATCH 235/577] More pre-commit checks --- .pre-commit-config.yaml | 10 +-- napari_cellseg3d/_tests/fixtures.py | 6 +- napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- napari_cellseg3d/_tests/test_utils.py | 25 ++++--- .../_tests/test_weight_download.py | 6 +- napari_cellseg3d/code_models/crf.py | 11 +-- .../code_models/model_framework.py | 10 ++- .../code_models/model_instance_seg.py | 13 ++-- napari_cellseg3d/code_models/model_workers.py | 11 +-- .../code_models/models/wnet/crf.py | 8 ++- napari_cellseg3d/code_plugins/plugin_base.py | 3 +- .../code_plugins/plugin_convert.py | 16 ++--- napari_cellseg3d/code_plugins/plugin_crop.py | 5 +- .../code_plugins/plugin_model_inference.py | 17 +++-- .../code_plugins/plugin_model_training.py | 8 +-- .../code_plugins/plugin_review.py | 13 ++-- .../code_plugins/plugin_review_dock.py | 5 +- .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 5 +- .../dev_scripts/artefact_labeling.py | 3 +- napari_cellseg3d/dev_scripts/convert.py | 3 +- .../dev_scripts/correct_labels.py | 3 +- napari_cellseg3d/interface.py | 67 +++++++++---------- napari_cellseg3d/utils.py | 59 ++++++++-------- pyproject.toml | 7 +- 25 files changed, 166 insertions(+), 158 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..61ecaae5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index bd6b0ac7..b3044799 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,7 +1,7 @@ -import warnings - from qtpy.QtWidgets import QTextEdit +from napari_cellseg3d.utils import LOGGER as logger + class LogFixture(QTextEdit): """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing""" @@ -13,7 +13,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - warnings.warn(warning) + logger.warning(warning) def error(self, e): raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 7908e8b4..584be4d7 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index dc57b940..f2a9d32c 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,8 +1,7 @@ import os -import warnings +from functools import partial import numpy as np -import pytest import torch from napari_cellseg3d import utils @@ -33,6 +32,10 @@ def test_fill_list_in_between(): assert utils.fill_list_in_between(list, 2, "") == res + fill = partial(utils.fill_list_in_between, n=2, fill_value="") + + assert fill(list) == res + def test_align_array_sizes(): im = np.zeros((128, 512, 256)) @@ -79,15 +82,15 @@ def test_get_padding_dim(): tensor = torch.randn(2000, 30, 40) size = tensor.size() - warn = warnings.warn( - "Warning : a very large dimension for automatic padding has been computed.\n" - "Ensure your images are of an appropriate size and/or that you have enough memory." - "The padding value is currently 2048." - ) - - pad = utils.get_padding_dim(size) - - pytest.warns(warn, (lambda: utils.get_padding_dim(size))) + # warn = logger.warning( + # "Warning : a very large dimension for automatic padding has been computed.\n" + # "Ensure your images are of an appropriate size and/or that you have enough memory." + # "The padding value is currently 2048." + # ) + # + # pad = utils.get_padding_dim(size) + # + # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) assert pad == [2048, 32, 64] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 1bcb40d7..b9d4abe5 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,5 +1,7 @@ -from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR -from napari_cellseg3d.code_models.model_workers import WeightsDownloader +from napari_cellseg3d.code_models.model_workers import ( + PRETRAINED_WEIGHTS_DIR, + WeightsDownloader, +) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 13f489c7..fc1e0b90 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -16,15 +16,18 @@ try: import pydensecrf.densecrf as dcrf - from pydensecrf.utils import create_pairwise_bilateral - from pydensecrf.utils import create_pairwise_gaussian - from pydensecrf.utils import unary_from_softmax + from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, + ) CRF_INSTALLED = True except ImportError: warn( "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]" + "Please install by running pip install cellseg3d[crf]", + stacklevel=1, ) CRF_INSTALLED = False diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index d541b486..37fc6a49 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import napari @@ -12,7 +11,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -135,11 +133,11 @@ def save_log(self): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) else: - warnings.warn(f"No logger defined : Log is {self.log}") + logger.warning(f"No logger defined : Log is {self.log}") def save_log_to_path(self, path): """Saves the worker log to a specific path. Cannot be used with connect. @@ -161,7 +159,7 @@ def save_log_to_path(self, path): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) @@ -170,7 +168,7 @@ def display_status_report(self): (usually when starting a worker)""" # if self.container_report is None or self.log is None: - # warnings.warn( + # logger.warning( # "Status report widget has been closed. Trying to re-instantiate..." # ) # self.container_report = QWidget() diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2f10aa1f..d551920d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,11 +1,11 @@ from dataclasses import dataclass +from functools import partial from typing import List import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -14,9 +14,8 @@ from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area @@ -399,8 +398,10 @@ def sphericity(region): volume = [region.area for region in properties] - def fill(lst, n=len(properties) - 1): - return fill_list_in_between(lst, n, "") + # def fill(lst, n=len(properties) - 1): + # return fill_list_in_between(lst, n, "") + + fill = partial(fill_list_in_between, n=len(properties) - 1, fill_value="") if len(volume_image.flatten()) != 0: ratio = fill([np.sum(volume) / len(volume_image.flatten())]) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index dce2f452..65b1c80a 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -42,8 +42,7 @@ # from napari.qt.threading import thread_worker # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -51,8 +50,12 @@ from tqdm import tqdm # local -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index 2ac0875d..004db3a1 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,9 +12,11 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import create_pairwise_bilateral -from pydensecrf.utils import create_pairwise_gaussian -from pydensecrf.utils import unary_from_softmax +from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, +) __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 2cb3581b..e7b97e01 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from pathlib import Path @@ -404,7 +403,7 @@ def load_dataset_paths(self): file_paths = sorted(Path(directory).glob("*" + filetype)) if len(file_paths) == 0: - warnings.warn( + logger.warning( f"The folder does not contain any compatible {filetype} files.\n" f"Please check the validity of the folder and images." ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index a847ebf7..0bff4cae 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,18 +1,18 @@ -import warnings from pathlib import Path import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -84,7 +84,7 @@ def show_result(viewer, layer, image, name): logger.debug("Added resulting label layer") viewer.add_labels(image, name=name) else: - warnings.warn( + logger.warning( f"Results not shown, unsupported layer type {type(layer)}" ) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 1647e858..d82df475 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import napari @@ -245,7 +244,7 @@ def _start(self): # maybe use singletons or make docked widgets attributes that are hidden upon opening if not self._check_ready(): - warnings.warn("Please select at least one valid layer !") + logger.warning("Please select at least one valid layer !") return # self._viewer.window.remove_dock_widget(self.parent()) # no need to close utils ? @@ -329,7 +328,7 @@ def add_isotropic_layer( self, layer, colormap="inferno", - contrast_lim=[200, 1000], # TODO generalize ? + contrast_lim=(200, 1000), # TODO generalize ? opacity=0.7, visible=True, ): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 522f91bb..eab16c8b 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,4 +1,3 @@ -import warnings from functools import partial import napari @@ -9,10 +8,16 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) + +logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -539,7 +544,7 @@ def start(self): if not self._check_results_path(save_path): msg = f"ERROR: please set valid results path. Current path is {save_path}" self.log.print_and_log(msg) - warnings.warn(msg) + logger.warning(msg) else: if self.results_path is None: self.results_path = save_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 132c9531..88991f43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,5 +1,4 @@ import shutil -import warnings from functools import partial from pathlib import Path @@ -418,8 +417,7 @@ def check_ready(self): if self.images_filepaths != [] and self.labels_filepaths != []: return True else: - warnings.formatwarning = utils.format_Warning - warnings.warn("Image and label paths are not correctly set") + logger.warning("Image and label paths are not correctly set") return False def _build(self): @@ -787,7 +785,7 @@ def start(self): if not self.check_ready(): # issues a warning if not ready err = "Aborting, please set all required paths" self.log.print_and_log(err) - warnings.warn(err) + logger.warning(err) return if self.worker is not None: @@ -1043,7 +1041,7 @@ def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) if len(self.loss_values) == 0 or self.loss_values is None: - warnings.warn("No loss values to add to csv !") + logger.warning("No loss values to add to csv !") return self.df = pd.DataFrame( diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index e3e05f6c..235595e4 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -20,7 +19,6 @@ from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -180,10 +178,11 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: - warnings.warn( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None: + if cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -237,7 +236,7 @@ def run_review(self): self._reset() previous_viewer.close() except ValueError as e: - warnings.warn( + logger.warning( f"An exception occurred : {e}. Please ensure you have entered all required parameters." ) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index c09c376f..8753a642 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,4 +1,3 @@ -import warnings from datetime import datetime, timedelta from pathlib import Path @@ -16,7 +15,7 @@ GUI_MINIMUM_HEIGHT = 300 TIMER_FORMAT = "%H:%M:%S" - +logger = utils.LOGGER """ plugin_dock.py ==================================== @@ -261,7 +260,7 @@ def update_dm(self, slice_num): def button_func(self): # updates csv every time you press button... if self.viewer.dims.ndisplay != 2: # TODO test if undefined behaviour or if okay - warnings.warn("Please switch back to 2D mode !") + logger.warning("Please switch back to 2D mode !") return self.update_time_csv() diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 45c0c119..462ee450 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,9 +2,7 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 43f961f4..2b38eb29 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -1,5 +1,4 @@ import datetime -import warnings from dataclasses import dataclass from pathlib import Path from typing import List, Optional @@ -84,9 +83,9 @@ def get_model(self): return MODEL_LIST[self.name] except KeyError as e: msg = f"Model {self.name} is not defined" - warnings.warn(msg) logger.warning(msg) - raise KeyError(e) + logger.warning(msg) + raise KeyError from e @staticmethod def get_model_name_list(): diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..b4712aec 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 479a07dd..641de627 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,8 +2,7 @@ import os import numpy as np -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..2ab60332 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 9a100dc2..36fc9aab 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,5 +1,4 @@ import threading -import warnings from functools import partial from typing import List, Optional @@ -9,32 +8,30 @@ from qtpy import QtCore # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -288,10 +285,10 @@ def print_and_log(self, text, printing=True): self.lock.release() def warn(self, warning): - """Show warnings.warn from another thread""" + """Show logger.warning from another thread""" self.lock.acquire() try: - warnings.warn(warning) + logger.warning(warning) finally: self.lock.release() @@ -536,7 +533,7 @@ def _build_container(self): ) def _warn_outside_bounds(self, default): - warnings.warn( + logger.warning( f"Default value {default} was outside of the ({self.minimum()}:{self.maximum()}) range" ) @@ -581,7 +578,7 @@ def slider_value(self): try: return self.value() / self._divide_factor except ZeroDivisionError as e: - raise ZeroDivisionError( + raise ZeroDivisionError from ( f"Divide factor cannot be 0 for Slider : {e}" ) @@ -791,8 +788,8 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: - warnings.warn("Please select a valid layer !") - return None + logger.warning("Please select a valid layer !") + return return self._viewer.layers[self.layer_name()].data @@ -1188,7 +1185,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), filetype: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1212,7 +1209,7 @@ def open_file_dialog( def open_folder_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), ): default_path = utils.parse_default_path(possible_paths) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 9fbe6d7a..f3fc09ba 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,5 +1,4 @@ import logging -import warnings from datetime import datetime from pathlib import Path @@ -234,7 +233,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # warnings.warn( + # logger.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -242,7 +241,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 if pad >= 256: - warnings.warn( + LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." f"The padding value is currently {pad}." @@ -342,14 +341,14 @@ def annotation_to_input(label_ermito): # pass -def fill_list_in_between(lst, n, elem): +def fill_list_in_between(lst, n, fill_value): """Fills a list with n * elem between each member of list. Example with list = [1,2,3], n=2, elem='&' : returns [1, &, &,2,&,&,3,&,&] Args: lst: list to fill n: number of elements to add - elem: added n times after each element of list + fill_value: added n times after each element of list Returns : Filled list @@ -358,13 +357,13 @@ def fill_list_in_between(lst, n, elem): for i in range(len(lst)): temp_list = [lst[i]] while len(temp_list) < n + 1: - temp_list.append(elem) + temp_list.append(fill_value) if i < len(lst) - 1: new_list += temp_list else: new_list.append(lst[i]) for _j in range(n): - new_list.append(elem) + new_list.append(fill_value) return new_list return None @@ -535,26 +534,26 @@ def select_train_data(dataframe, ori_imgs, label_imgs, ori_filenames): return np.array(train_ori_imgs), np.array(train_label_imgs) -def format_Warning(message, category, filename, lineno, line=""): - """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` - - Args: - message: warning message - category: which type of warning has been raised - filename: file - lineno: line number - line: unused - - Returns: format - - """ - return ( - str(filename) - + ":" - + str(lineno) - + ": " - + category.__name__ - + ": " - + str(message) - + "\n" - ) +# def format_Warning(message, category, filename, lineno, line=""): +# """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` +# +# Args: +# message: warning message +# category: which type of warning has been raised +# filename: file +# lineno: line number +# line: unused +# +# Returns: format +# +# """ +# return ( +# str(filename) +# + ":" +# + str(lineno) +# + ": " +# + category.__name__ +# + ": " +# + str(message) +# + "\n" +# ) diff --git a/pyproject.toml b/pyproject.toml index d9a46ccf..8e7187f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,12 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). +select = [ + "E", "F", "W", + "I", + "B", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) ignore = ["E501", "E741"] [tool.black] From 6889b0de370fabbf6aa8cdebf4fd8cb1dbd45894 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:29:42 +0200 Subject: [PATCH 236/577] Functional CRF --- napari_cellseg3d/_tests/test_models.py | 39 +++ napari_cellseg3d/code_models/crf.py | 98 ++++++- .../code_models/model_instance_seg.py | 6 +- napari_cellseg3d/code_models/model_workers.py | 48 +++- napari_cellseg3d/code_plugins/plugin_base.py | 23 +- .../code_plugins/plugin_convert.py | 194 ++++--------- napari_cellseg3d/code_plugins/plugin_crf.py | 262 ++++++++++++++++++ napari_cellseg3d/code_plugins/plugin_crop.py | 7 +- .../code_plugins/plugin_model_inference.py | 32 ++- .../code_plugins/plugin_utilities.py | 15 +- napari_cellseg3d/config.py | 16 ++ napari_cellseg3d/interface.py | 19 +- napari_cellseg3d/utils.py | 81 +++++- 13 files changed, 671 insertions(+), 169 deletions(-) create mode 100644 napari_cellseg3d/code_plugins/plugin_crf.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 9280b230..1fc15872 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,9 +1,18 @@ +import numpy as np import torch +from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +def test_correct_shape_for_crf(): + test = np.random.rand(1, 1, 8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + test = np.random.rand(8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + + def test_model_list(): for model_name in MODEL_LIST.keys(): dims = 128 @@ -31,3 +40,33 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) # assert res > 0 + + +def test_crf(qtbot): + dims = 8 + mock_image = np.random.rand(1, dims, dims, dims) + mock_label = np.random.rand(2, dims, dims, dims) + + crf = CRFWorker(mock_image, mock_label) + + def on_yield(result): + assert isinstance(result, np.ndarray) + assert result.shape[-3:] == mock_label.shape[-3:] + + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) + + mock_image = mock_image[0] + mock_label = mock_label[0] + + crf = CRFWorker(mock_image, mock_label) + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index fc1e0b90..a0146a5e 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -9,11 +9,8 @@ Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. """ - from warnings import warn -import numpy as np - try: import pydensecrf.densecrf as dcrf from pydensecrf.utils import ( @@ -31,6 +28,12 @@ ) CRF_INSTALLED = False + +import numpy as np +from napari.qt.threading import GeneratorWorker + +from napari_cellseg3d.config import CRFConfig + __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ "Yves Paychère", @@ -49,6 +52,16 @@ ] +def correct_shape_for_crf(image): + if len(image.shape) == 4: + return image + if len(image.shape) > 4: + image = np.squeeze(image, axis=0) + if len(image.shape) < 4: + image = np.expand_dims(image, axis=0) + return correct_shape_for_crf(image) + + def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): """CRF post-processing step for the W-Net, applied to a batch of images. @@ -62,6 +75,8 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): Returns: np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. """ + if not CRF_INSTALLED: + return None return np.stack( [ @@ -83,10 +98,16 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + w1 (float): weight of the appearance/bilateral kernel. + w2 (float): weight of the smoothness/gaussian kernel. Returns: np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. """ + + if not CRF_INSTALLED: + return None + d = dcrf.DenseCRF( image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] ) @@ -123,3 +144,74 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): return np.array(Q).reshape( (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) ) + + +def crf_with_config(image, prob, config: CRFConfig = None): + if config is None: + config = CRFConfig() + if image.shape[-3:] != prob.shape[-3:]: + raise ValueError( + f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" + f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" + ) + + image = correct_shape_for_crf(image) + + return crf( + image, + prob, + config.sa, + config.sb, + config.sg, + config.w1, + config.w2, + config.n_iters, + ) + + +class CRFWorker(GeneratorWorker): + """Worker for the CRF post-processing step for the W-Net.""" + + def __init__( + self, + images_list, + labels_list, + config: CRFConfig = None, + log=None, + ): + super().__init__(self._run_crf_job) + + self.images = images_list + self.labels = labels_list + if config is None: + self.config = CRFConfig() + else: + self.config = config + self.log = log + + # TODO(cyril) : add progress bar into log ? or do it in inference + def _run_crf_job(self): + """Runs the CRF post-processing step for the W-Net.""" + if not CRF_INSTALLED: + raise ImportError("pydensecrf is not installed.") + + for image, labels in zip(self.images, self.labels): + if len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + + if len(labels.shape) == 3: + labels = np.expand_dims(labels, axis=0) + + if image.shape[-3:] != labels.shape[-3:]: + raise ValueError("Image and labels must have the same shape.") + + yield crf( + image, + labels, + self.config.sa, + self.config.sb, + self.config.sg, + self.config.w1, + self.config.w2, + n_iter=self.config.n_iters, + ) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d551920d..d1a03eec 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -65,7 +65,7 @@ def __init__( 1, divide_factor=100, text_label="", - parent=None, + parent=widget_parent, ), ) self.sliders.append(getattr(self, widget)) @@ -76,7 +76,9 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(text_label="", parent=None), + ui.DoubleIncrementCounter( + text_label="", parent=widget_parent + ), ) self.counters.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 65b1c80a..c7196db7 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -52,6 +52,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import crf_with_config from napari_cellseg3d.code_models.model_instance_seg import ( ImageStats, volume_stats, @@ -201,6 +202,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None + crf_results: np.array = None stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -528,7 +530,8 @@ def create_inference_result( self, semantic_labels, instance_labels, - from_layer: bool, + crf_results=None, + from_layer: bool = False, original=None, stats=None, i=0, @@ -543,15 +546,19 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - total_dims = len(semantic_labels.shape) - 3 + extra_dims = len(semantic_labels.shape) - 3 semantic_labels = np.swapaxes( - semantic_labels, 0 + total_dims, 2 + total_dims + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims ) return InferenceResult( image_id=i + 1, original=original, instance_labels=instance_labels, + crf_results=crf_results, stats=stats, result=semantic_labels, model_name=self.config.model_info.name, @@ -584,6 +591,7 @@ def save_image( image, from_layer=False, i=0, + additional_info="", ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" @@ -597,7 +605,7 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"Prediction_{i+1}" + + f"{additional_info}_Prediction_{i+1}" + original_filename + self.config.model_info.name + f"_{time}_" @@ -679,6 +687,15 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.save_image(out, i=i) instance_labels, stats = self.get_instance_result(out, i=i) + if self.config.use_crf: + try: + crf_results = self.run_crf(inputs, out, image_id=i) + + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + crf_results = None + else: + crf_results = None original = np.array(inf_data["image"]).astype(np.float32) @@ -687,12 +704,29 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): return self.create_inference_result( out, instance_labels, + crf_results, from_layer=False, original=original, stats=stats, i=i, ) + def run_crf(self, image, labels, image_id=0): + self.log(f"IMAGE SHAPE : {image.shape}") + self.log(f"LABEL SHAPE : {labels.shape}") + + try: + crf_results = crf_with_config( + image, labels, config=self.config.crf_config + ) + self.save_image( + crf_results, i=image_id, additional_info="CRF", from_layer=True + ) + return crf_results + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + return None + def stats_csv(self, instance_labels): if self.config.compute_stats: stats = volume_stats(instance_labels) @@ -729,9 +763,15 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels_results.append(instance_labels) stats_results.append(stats) + if self.config.use_crf: + crf_results = self.run_crf(image, out) + else: + crf_results = None + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels_results, + crf_results=crf_results, from_layer=True, stats=stats_results, ) diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index e7b97e01..26da7a42 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -46,15 +46,15 @@ def __init__( self.image_path = None """str: path to image folder""" - self.show_image_io = loads_images + self._show_image_io = loads_images self.label_path = None """str: path to label folder""" - self.show_label_io = loads_labels + self._show_label_io = loads_labels self.results_path = None """str: path to results folder""" - self.show_results_io = has_results + self._show_results_io = has_results self._default_path = [self.image_path, self.label_path] @@ -117,7 +117,6 @@ def show_menu(_, event): def _build_io_panel(self): self.io_panel = ui.GroupedWidget("Data") self.save_label = ui.make_label("Save location :", parent=self) - # self.io_panel.setToolTip("IO Panel") ui.add_widgets( @@ -139,25 +138,25 @@ def _build_io_panel(self): return self.io_panel def _remove_unused(self): - if not self.show_label_io: + if not self._show_label_io: self.labels_filewidget = None self.label_layer_loader = None - if not self.show_image_io: + if not self._show_image_io: self.image_layer_loader = None self.image_filewidget = None - if not self.show_results_io: + if not self._show_results_io: self.results_filewidget = None def _set_io_visibility(self): ################## # Show when layer is selected - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_layer_loader, self.layer_choice) else: self._hide_io_element(self.image_layer_loader) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.label_layer_loader, self.layer_choice) else: self._hide_io_element(self.label_layer_loader) @@ -167,15 +166,15 @@ def _set_io_visibility(self): f = self.folder_choice self._show_io_element(self.filetype_choice, f) - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_filewidget, f) else: self._hide_io_element(self.image_filewidget) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.labels_filewidget, f) else: self._hide_io_element(self.labels_filewidget) - if not self.show_results_io: + if not self._show_results_io: self._hide_io_element(self.results_filewidget) self.folder_choice.toggle() diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 0bff4cae..f7b476d0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -3,7 +3,7 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -15,80 +15,12 @@ ) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -# TODO break down into multiple mini-widgets -# TODO create parent class for utils modules to avoid duplicates - -MAX_W = 200 -MAX_H = 1000 +MAX_W = ui.UTILS_MAX_WIDTH +MAX_H = ui.UTILS_MAX_HEIGHT logger = utils.LOGGER -def save_folder(results_path, folder_name, images, image_paths): - """ - Saves a list of images in a folder - - Args: - results_path: Path to the folder containing results - folder_name: Name of the folder containing results - images: List of images to save - image_paths: list of filenames of images - """ - results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False, parents=True) - - for file, image in zip(image_paths, images): - path = results_folder / Path(file).name - - imwrite( - path, - image, - ) - logger.info(f"Saved processed folder as : {results_folder}") - - -def save_layer(results_path, image_name, image): - """ - Saves an image layer at the specified path - - Args: - results_path: path to folder containing result - image_name: image name for saving - image: data array containing image - - Returns: - - """ - path = str(results_path / Path(image_name)) # TODO flexible filetype - logger.info(f"Saved as : {path}") - imwrite(path, image) - - -def show_result(viewer, layer, image, name): - """ - Adds layers to a viewer to show result to user - - Args: - viewer: viewer to add layer in - layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer - image: the data array containing the image - name: name of the added layer - - Returns: - - """ - if isinstance(layer, napari.layers.Image): - logger.debug("Added resulting image layer") - viewer.add_image(image, name=name) - elif isinstance(layer, napari.layers.Labels): - logger.debug("Added resulting label layer") - viewer.add_labels(image, name=name) - else: - logger.warning( - f"Results not shown, unsupported layer type {type(layer)}" - ) - - class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -154,31 +86,30 @@ def _start(self): data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) - save_layer( + utils.save_layer( self.results_path, f"isotropic_{layer.name}_{utils.get_date_time()}.tif", isotropic_image, ) - show_result( + utils.show_result( self._viewer, layer, isotropic_image, f"isotropic_{layer.name}", ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -254,27 +185,26 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"cleared_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -336,12 +266,12 @@ def _start(self): data = np.array(layer.data) semantic = to_semantic(data) - save_layer( + utils.save_layer( self.results_path, f"semantic_{layer.name}_{utils.get_date_time()}.tif", semantic, ) - show_result( + utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) elif self.folder_choice.isChecked(): @@ -350,7 +280,7 @@ def _start(self): to_semantic(file, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"semantic_results_{utils.get_date_time()}", images, @@ -421,7 +351,7 @@ def _start(self): data = np.array(layer.data) instance = self.instance_widgets.run_method(data) - save_layer( + utils.save_layer( self.results_path, f"instance_{layer.name}_{utils.get_date_time()}.tif", instance, @@ -430,19 +360,18 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.instance_widgets.run_method(imread(file)) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -517,27 +446,26 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"threshold_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py new file mode 100644 index 00000000..3dbd47bb --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -0,0 +1,262 @@ +from functools import partial +from pathlib import Path + +import napari.layers +from qtpy.QtWidgets import QSizePolicy +from tqdm import tqdm + +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage +from napari_cellseg3d.utils import LOGGER as logger + + +# TODO add CRF on folder +class CRFParamsWidget(ui.GroupedWidget): + """Use this widget when adding the crf as part of another widget (rather than a standalone widget)""" + + def __init__(self, parent=None): + super().__init__(title="CRF parameters", parent=parent) + ####### + # CRF params # + self.sa_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Alpha std" + ) + self.sb_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Beta std" + ) + self.sg_choice = ui.DoubleIncrementCounter( + default=1, parent=self, text_label="Gamma std" + ) + self.w1_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Weight appearance" + ) + self.w2_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Weight smoothness" + ) + self.n_iter_choice = ui.IntIncrementCounter( + default=5, parent=self, text_label="Number of iterations" + ) + ####### + self._build() + self._set_tooltips() + + def _build(self): + ui.add_widgets( + self.layout, + [ + # self.sa_choice.label, + self.sa_choice, + # self.sb_choice.label, + self.sb_choice, + # self.sg_choice.label, + self.sg_choice, + # self.w1_choice.label, + self.w1_choice, + # self.w2_choice.label, + self.w2_choice, + # self.n_iter_choice.label, + self.n_iter_choice, + ], + ) + self.set_layout() + + def _set_tooltips(self): + self.sa_choice.setToolTip( + "SA : Standard deviation of the Gaussian kernel in the appearance term." + ) + self.sb_choice.setToolTip( + "SB : Standard deviation of the Gaussian kernel in the smoothness term." + ) + self.sg_choice.setToolTip( + "SG : Standard deviation of the Gaussian kernel in the gradient term." + ) + self.w1_choice.setToolTip( + "W1 : Weight of the appearance term in the CRF." + ) + self.w2_choice.setToolTip( + "W2 : Weight of the smoothness term in the CRF." + ) + self.n_iter_choice.setToolTip("Number of iterations of the CRF.") + + def make_config(self): + return config.CRFConfig( + sa=self.sa_choice.value(), + sb=self.sb_choice.value(), + sg=self.sg_choice.value(), + w1=self.w1_choice.value(), + w2=self.w2_choice.value(), + n_iters=self.n_iter_choice.value(), + ) + + +class CRFWidget(BasePluginSingleImage): + def __init__(self, viewer, parent=None): + """ + Create a widget for CRF post-processing. + Args: + viewer: napari viewer to display the widget + parent: parent widget. Defaults to None. + """ + super().__init__(viewer, parent) + self._viewer = viewer + + self.start_button = ui.Button("Start", self._start, parent=self) + self.crf_params_widget = CRFParamsWidget(parent=self) + self.io_panel = self._build_io_panel() + self.io_panel.setVisible(False) + + self.results_filewidget.setVisible(True) + self.label_layer_loader.setVisible(True) + self.label_layer_loader.set_layer_type( + napari.layers.Image + ) # to load all crf-compatible inputs, not int only + self.image_layer_loader.setVisible(True) + self.start_button.setVisible(True) + + self.result_layer = None + self.result_name = None + self.crf_results = [] + + self.results_path = Path.home() / Path("cellseg3d/crf") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._container = ui.ContainerWidget(parent=self, l=11, t=11, r=11) + self.layout = self._container.layout + + self._build() + + self.worker = None + self.log = None + + def _build(self): + self.setMinimumWidth(100) + ui.add_widgets( + self.layout, + [ + self.image_layer_loader, + self.label_layer_loader, + self.save_label, + self.results_filewidget, + ui.make_label(""), + self.crf_params_widget, + ui.make_label(""), + self.start_button, + ], + ) + # self.io_panel.setLayout(self.io_panel.layout) + self.setLayout(self.layout) + + ui.ScrollArea.make_scrollable( + self.layout, self, max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT] + ) + self._container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return self._container + + def make_config(self): + return self.crf_params_widget.make_config() + + def _check_ready(self): + if len(self.label_layer_loader.layer_list) < 1: + logger.warning("No label layer loaded") + return False + if len(self.image_layer_loader.layer_list) < 1: + logger.warning("No image layer loaded") + return False + + if len(self.label_layer_loader.layer_data().shape) < 3: + logger.warning("Label layer must be 3D") + return False + if len(self.image_layer_loader.layer_data().shape) < 3: + logger.warning("Image layer must be 3D") + return False + if ( + self.label_layer_loader.layer_data().shape[-3:] + != self.image_layer_loader.layer_data().shape[-3:] + ): + logger.warning("Image and label layers must have the same shape!") + return False + + return True + + def run_crf_on_batch(self, images_list: list, labels_list: list, log=None): + self.crf_results = [] + for image, label in zip(images_list, labels_list): + tqdm( + unit="B", + total=len(images_list), + position=0, + file=log, + ) + result = crf_with_config(image, label, self.make_config()) + self.crf_results.append(result) + return self.crf_results + + def _prepare_worker(self, images_list: list, labels_list: list): + self.worker = CRFWorker( + images_list=images_list, + labels_list=labels_list, + config=self.make_config(), + ) + + self.worker.started.connect(self._on_start) + self.worker.yielded.connect(partial(self._on_yield)) + self.worker.errored.connect(partial(self._on_error)) + self.worker.finished.connect(self._on_finish) + + def _start(self): + if not self._check_ready(): + return + + self.result_layer = self.label_layer_loader.layer() + self.result_name = self.label_layer_loader.layer_name() + + self.results_path.mkdir(exist_ok=True, parents=True) + + image_list = [self.image_layer_loader.layer_data()] + labels_list = [self.label_layer_loader.layer_data()] + [logger.debug(f"Image shape: {image.shape}") for image in image_list] + [ + logger.debug(f"Label shape: {labels.shape}") + for labels in labels_list + ] + + self._prepare_worker(image_list, labels_list) + + if self.worker.is_running: # if worker is running, tries to stop + logger.info("Stop request, waiting for previous job to finish") + self.start_button.setText("Stopping...") + self.worker.quit() + else: # once worker is started, update buttons + self.start_button.setText("Running...") + logger.info("Starting CRF...") + self.worker.start() + + def _on_yield(self, result): + self.crf_results.append(result) + + utils.save_layer( + self.results_filewidget.text_field.text(), + str(self.result_name + "_crf.tif"), + result, + ) + self._viewer.add_image( + result, + name="crf_" + self.result_name, + ) + + def _on_start(self): + self.crf_results = [] + + def _on_finish(self): + self.worker = None + + def _on_error(self, error): + logger.error(error) + self.start_button.setText("Start") + self.worker.quit() + self.worker = None diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index d82df475..6e7f91f3 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -174,7 +174,12 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + ui.ScrollArea.make_scrollable( + layout, + self, + max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], + min_wh=[200, 200], + ) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index eab16c8b..472cccd8 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -16,6 +16,7 @@ InferenceResult, InferenceWorker, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget logger = utils.LOGGER @@ -195,9 +196,17 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) + self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( - "Run instance segmentation", func=self._toggle_display_instance + "Run instance segmentation", + func=self._toggle_display_instance, + parent=self, + ) + self.use_crf = ui.CheckBox( + "Use CRF post-processing", + func=self._toggle_display_crf, + parent=self, ) self.save_stats_to_csv_box = ui.CheckBox( @@ -309,6 +318,10 @@ def _toggle_display_thresh(self): self.thresholding_checkbox, self.thresholding_slider.container ) + def _toggle_display_crf(self): + """Shows the choices for CRF post-processing depending on whether :py:attr:`self.use_crf` is checked""" + ui.toggle_visibility(self.use_crf, self.crf_widgets) + def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection""" ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) @@ -426,6 +439,8 @@ def _build(self): self.thresholding_slider.container, # thresholding self.use_instance_choice, self.instance_widgets, + self.use_crf, + self.crf_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -437,6 +452,7 @@ def _build(self): self.anisotropy_wdgt.container.setVisible(False) self.thresholding_slider.container.setVisible(False) self.instance_widgets.setVisible(False) + self.crf_widgets.setVisible(False) self.save_stats_to_csv_box.setVisible(False) post_proc_group.setLayout(post_proc_layout) @@ -588,6 +604,8 @@ def start(self): compute_stats=self.save_stats_to_csv_box.isChecked(), post_process_config=self.post_process_config, sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), ) ##################### ##################### @@ -737,7 +755,10 @@ def on_yield(self, result: InferenceResult): opacity=0.8, ) - if result.instance_labels is not None: + if ( + len(result.instance_labels) > 0 + and self.worker_config.post_process_config.instance.enabled + ): for i, labels in enumerate(result.instance_labels): # labels = result.instance_labels method_name = ( @@ -779,5 +800,12 @@ def on_yield(self, result: InferenceResult): # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) + + if result.crf_results is not None: + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 462ee450..868dd279 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -13,6 +13,7 @@ ToInstanceUtils, ToSemanticUtils, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -22,6 +23,7 @@ "Convert to instance labels": ToInstanceUtils, "Convert to semantic labels": ToSemanticUtils, "Threshold": ThresholdUtils, + "CRF": CRFWidget, } @@ -30,7 +32,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() self._viewer = viewer - attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh"] + attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh", "crf"] self._create_utils_widgets(attr_names) # self.crop = Cropping(self._viewer) @@ -54,8 +56,15 @@ def __init__(self, viewer: "napari.viewer.Viewer"): def _build(self): layout = QVBoxLayout() ui.add_widgets(layout, self.utils_widgets) - layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) - layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) + ui.GroupedWidget.create_single_widget_group( + "Utilities", + widget=self.utils_choice, + layout=layout, + alignment=ui.BOTT_AL, + ) + + # layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) + # layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 2b38eb29..8a7c1565 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -137,6 +137,20 @@ class PostProcessConfig: instance: InstanceSegConfig = InstanceSegConfig() +@dataclass +class CRFConfig: + """ + Class to record params for CRF + """ + + sa: float = 10 + sb: float = 5 + sg: float = 1 + w1: float = 10 + w2: float = 5 + n_iters: int = 5 + + ################ # Inference configs @@ -196,6 +210,8 @@ class InferenceWorkerConfig: compute_stats: bool = False post_process_config: PostProcessConfig = PostProcessConfig() sliding_window_config: SlidingWindowConfig = SlidingWindowConfig() + use_crf: bool = False + crf_config: CRFConfig = CRFConfig() images_filepaths: str = None layer: napari.layers.Layer = None diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 36fc9aab..55e5abb3 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -57,6 +57,8 @@ """Alias for Qt.AlignmentFlag.AlignAbsolute, to use in addWidget""" BOTT_AL = Qt.AlignmentFlag.AlignBottom """Alias for Qt.AlignmentFlag.AlignBottom, to use in addWidget""" +TOP_AL = Qt.AlignmentFlag.AlignTop +"""Alias for Qt.AlignmentFlag.AlignTop, to use in addWidget""" ############### # colors dark_red = "#72071d" # crimson red @@ -65,6 +67,9 @@ napari_param_grey = "#414851" # napari parameters menu color (lighter gray) napari_param_darkgrey = "#202228" # napari default LineEdit color ############### +# dimensions for utils ScrollArea +UTILS_MAX_WIDTH = 300 +UTILS_MAX_HEIGHT = 500 logger = utils.LOGGER @@ -791,7 +796,7 @@ def layer_data(self): logger.warning("Please select a valid layer !") return - return self._viewer.layers[self.layer_name()].data + return self.layer().data class FilePathWidget(QWidget): # TODO include load as folder @@ -1277,12 +1282,20 @@ def set_layout(self): @classmethod def create_single_widget_group( - cls, title, widget, layout, l=7, t=20, r=7, b=11 + cls, + title, + widget, + layout, + l=7, + t=20, + r=7, + b=11, + alignment=LEFT_AL, ): group = cls(title, l, t, r, b) group.layout.addWidget(widget) group.setLayout(group.layout) - layout.addWidget(group) + layout.addWidget(group, alignment=alignment) def add_widgets(layout, widgets, alignment=LEFT_AL): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index f3fc09ba..e7eaf95a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,11 +2,12 @@ from datetime import datetime from pathlib import Path +import napari import numpy as np from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian -from tifffile import imread as tfl_imread +from tifffile import imread, imwrite LOGGER = logging.getLogger(__name__) ############### @@ -21,6 +22,76 @@ """ +#################### +# viewer utils +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False, parents=True) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + LOGGER.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + LOGGER.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + LOGGER.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + LOGGER.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + LOGGER.warning( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + +#################### + + class Singleton(type): """ Singleton class that can only be instantiated once at a time, @@ -44,7 +115,7 @@ def __call__(cls, *args, **kwargs): # if filename == "tif": # return True # def read(self, data, **kwargs): -# return tfl_imread(data) +# return imread(data) # # def get_data(self, data): # return data, {} @@ -233,7 +304,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # logger.warning( + # LOGGER.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -470,9 +541,7 @@ def load_images( ) # images_original = dask_imread(filename_pattern_original) else: - images_original = tfl_imread( - filename_pattern_original - ) # tifffile imread + images_original = imread(filename_pattern_original) # tifffile imread return images_original From c0d9daf23f04fd37a094771fe51a184a5735d635 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:37:33 +0200 Subject: [PATCH 237/577] Fix erroneous test comment, added toggle for crf - Warn if crf not installed - Fix test --- napari_cellseg3d/_tests/test_utils.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 22 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index f2a9d32c..0b28183d 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -88,7 +88,7 @@ def test_get_padding_dim(): # "The padding value is currently 2048." # ) # - # pad = utils.get_padding_dim(size) + pad = utils.get_padding_dim(size) # # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 3dbd47bb..cbdacf3a 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -7,7 +7,11 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_models.crf import ( + CRF_INSTALLED, + CRFWorker, + crf_with_config, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.utils import LOGGER as logger @@ -43,6 +47,17 @@ def __init__(self, parent=None): self._set_tooltips() def _build(self): + if not CRF_INSTALLED: + ui.add_widgets( + self.layout, + [ + ui.make_label( + "ERROR: CRF not installed.\nPlease refer to the documentation to install it." + ), + ], + ) + self.set_layout() + return ui.add_widgets( self.layout, [ @@ -113,7 +128,10 @@ def __init__(self, viewer, parent=None): napari.layers.Image ) # to load all crf-compatible inputs, not int only self.image_layer_loader.setVisible(True) - self.start_button.setVisible(True) + if CRF_INSTALLED: + self.start_button.setVisible(True) + else: + self.start_button.setVisible(False) self.result_layer = None self.result_name = None From 2b0f639bb915ddb62d79d6340c60d60c8e06c48d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:56:08 +0200 Subject: [PATCH 238/577] Specify missing test deps --- pyproject.toml | 3 ++- tox.ini | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e7187f5..5648ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ -"git+https://github.com/lucasb-eyer/pydensecrf.git", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,4 +81,5 @@ test = [ "coverage", "tox", "twine", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 87338cd8..a3eef589 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + "git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From f7aa2cc79f720b70dd3636fccffbc83b56a563e9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:02:31 +0200 Subject: [PATCH 239/577] Trying to fix deps on Git --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5648ab40..73fc862c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] diff --git a/tox.ini b/tox.ini index a3eef589..65a49bdd 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - "git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From e7b51a27c743217eaad5127d9b706145d2c63023 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:04:33 +0200 Subject: [PATCH 240/577] Removed master link to pydensecrf --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73fc862c..8d9d6bf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 65a49bdd..6f71b9db 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 47faea5e3590d9a69b2358bb88a5dbb1ee67f9ff Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:07:23 +0200 Subject: [PATCH 241/577] Use commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d9d6bf4..0cc237e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] diff --git a/tox.ini b/tox.ini index 6f71b9db..5e0777f3 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 6c580e3ca382603bd67c54aede48594efae05994 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:09:27 +0200 Subject: [PATCH 242/577] Removed commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0cc237e5..09ed8585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] diff --git a/tox.ini b/tox.ini index 5e0777f3..3d7df5d0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From a65b1c5a00bb964e174f0f66b10c16dd70a49053 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:11:27 +0200 Subject: [PATCH 243/577] Removed master --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 09ed8585..db39904b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] diff --git a/tox.ini b/tox.ini index 3d7df5d0..fd92727c 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From f37a9fb4a5cf387ae6adab7380838ec05aca8705 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:17:16 +0200 Subject: [PATCH 244/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index fd92727c..0a7c07f0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" + pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 66518886d2ed32347d28046bdc9cd7abaac94060 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 09:06:23 +0200 Subject: [PATCH 245/577] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db39904b..d223072a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] From 9265555cb4c738fccdb3d657fd6242fcb1923301 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 17:41:05 +0200 Subject: [PATCH 246/577] Fixes and improvements - More CRF info - Added error handling to scheduler rate - Added ETA to training - Updated padding warning trigger size --- napari_cellseg3d/code_models/crf.py | 30 ++++++++++------ napari_cellseg3d/code_models/model_workers.py | 34 ++++++++++++++----- .../code_models/models/model_VNet.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 6 ++++ .../code_plugins/plugin_model_inference.py | 3 ++ .../code_plugins/plugin_model_training.py | 6 ++-- napari_cellseg3d/utils.py | 6 ++-- 7 files changed, 61 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index a0146a5e..1b8dce28 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -33,6 +33,7 @@ from napari.qt.threading import GeneratorWorker from napari_cellseg3d.config import CRFConfig +from napari_cellseg3d.utils import LOGGER as logger __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ @@ -52,12 +53,16 @@ ] -def correct_shape_for_crf(image): - if len(image.shape) == 4: +def correct_shape_for_crf(image, desired_dims=4): + if len(image.shape) == desired_dims: return image - if len(image.shape) > 4: + if len(image.shape) > desired_dims: + if image.shape[0] > 1: + raise ValueError( + f"Image shape {image.shape} might have several channels" + ) image = np.squeeze(image, axis=0) - if len(image.shape) < 4: + if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) return correct_shape_for_crf(image) @@ -146,7 +151,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): ) -def crf_with_config(image, prob, config: CRFConfig = None): +def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): if config is None: config = CRFConfig() if image.shape[-3:] != prob.shape[-3:]: @@ -156,6 +161,12 @@ def crf_with_config(image, prob, config: CRFConfig = None): ) image = correct_shape_for_crf(image) + prob = correct_shape_for_crf(prob) + + if log is not None: + log("Running CRF post-processing step") + log(f"Image shape : {image.shape}") + log(f"Labels shape : {prob.shape}") return crf( image, @@ -196,15 +207,12 @@ def _run_crf_job(self): raise ImportError("pydensecrf is not installed.") for image, labels in zip(self.images, self.labels): - if len(image.shape) == 3: - image = np.expand_dims(image, axis=0) - - if len(labels.shape) == 3: - labels = np.expand_dims(labels, axis=0) - if image.shape[-3:] != labels.shape[-3:]: raise ValueError("Image and labels must have the same shape.") + image = correct_shape_for_crf(image) + labels = correct_shape_for_crf(labels) + yield crf( image, labels, diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index c7196db7..39e6bb91 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,4 +1,5 @@ import platform +import time import typing as t from dataclasses import dataclass from math import ceil @@ -598,7 +599,7 @@ def save_image( filetype = self.config.filetype else: original_filename = "_" - filetype = "" + filetype = ".tif" time = utils.get_date_time() @@ -712,12 +713,9 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) def run_crf(self, image, labels, image_id=0): - self.log(f"IMAGE SHAPE : {image.shape}") - self.log(f"LABEL SHAPE : {labels.shape}") - try: crf_results = crf_with_config( - image, labels, config=self.config.crf_config + image, labels, config=self.config.crf_config, log=self.log ) self.save_image( crf_results, i=image_id, additional_info="CRF", from_layer=True @@ -1152,6 +1150,8 @@ def train(self): weights_config = self.config.weights_info deterministic_config = self.config.deterministic_config + start_time = time.time() + try: if deterministic_config.enabled: set_determinism( @@ -1364,14 +1364,23 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + + factor = self.config.scheduler_factor + if factor >= 1.0: + self.log(f"Warning : scheduler factor is {factor} >= 1.0") + self.log("Setting it to 0.5") + factor = 0.5 + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode="min", - factor=self.config.scheduler_factor, + factor=factor, patience=self.config.scheduler_patience, verbose=VERBOSE_SCHEDULER, ) - dice_metric = DiceMetric(include_background=True, reduction="mean") + dice_metric = DiceMetric( + include_background=False, reduction="mean" + ) best_metric = -1 best_metric_epoch = -1 @@ -1467,6 +1476,15 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] + self.log( + "ETA: " + + str( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 + ) + + "minutes" + ) if ( (epoch + 1) % self.config.validation_interval == 0 @@ -1490,7 +1508,7 @@ def train(self): overlap=0.25, sw_device=self.config.device, device=self.config.device, - progress=True, + progress=False, ) except Exception as e: self.raise_error(e, "Error during validation") diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 41554e80..7aa6476e 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=1, **kwargs): + def __init__(self, in_channels=1, out_channels=2, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index cbdacf3a..7ac605e9 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -178,6 +178,11 @@ def _build(self): def make_config(self): return self.crf_params_widget.make_config() + def print_config(self): + logger.info("CRF config:") + for item in self.make_config().__dict__.items(): + logger.info(f"{item[0]}: {item[1]}") + def _check_ready(self): if len(self.label_layer_loader.layer_list) < 1: logger.warning("No label layer loaded") @@ -272,6 +277,7 @@ def _on_start(self): def _on_finish(self): self.worker = None + self.start_button.setText("Start") def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 472cccd8..157b8af7 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -802,6 +802,9 @@ def on_yield(self, result: InferenceResult): # ) if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) viewer.add_image( result.crf_results, name=f"CRF_results_image_{image_id}", diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 88991f43..86d1d317 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -846,7 +846,7 @@ def start(self): loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), @@ -982,7 +982,7 @@ def on_yield(self, report: TrainingReport): layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(layer) else: @@ -993,7 +993,7 @@ def on_yield(self, report: TrainingReport): new_layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(new_layer) self.result_layers[i].data = report.images[i] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index e7eaf95a..1aa316d2 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -12,8 +12,8 @@ LOGGER = logging.getLogger(__name__) ############### # Global logging level setting -# LOGGER.setLevel(logging.DEBUG) -LOGGER.setLevel(logging.INFO) +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) ############### """ utils.py @@ -311,7 +311,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 - if pad >= 256: + if pad >= 1024: LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." From 8ac41613e988dc0312e246900eaad9a2b539ab70 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 3 May 2023 09:57:34 +0200 Subject: [PATCH 247/577] Fixes and channel labeling prototype --- napari_cellseg3d/code_models/model_workers.py | 33 +++-- .../extract_extra_channels_labels.py | 124 ++++++++++++++++++ 2 files changed, 143 insertions(+), 14 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 39e6bb91..9f38a534 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -548,12 +548,14 @@ def create_inference_result( "A layer's ID should always be 0 (default value)" ) extra_dims = len(semantic_labels.shape) - 3 - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + if semantic_labels is not None: + semantic_labels = np.swapaxes( + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + if crf_results is not None: + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims + ) return InferenceResult( image_id=i + 1, @@ -1456,6 +1458,12 @@ def train(self): optimizer.zero_grad() outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") + if outputs.shape[1] > 1: + outputs = outputs[ + :, 1:, :, : + ] # FIXME fix channel number + if len(outputs.shape) < 4: + outputs = outputs.unsqueeze(0) loss = self.config.loss_function(outputs, labels) loss.backward() optimizer.step() @@ -1476,15 +1484,12 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] - self.log( - "ETA: " - + str( - (time.time() - start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 - ) - + "minutes" + eta = ( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 ) + self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py new file mode 100644 index 00000000..2bd0a536 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -0,0 +1,124 @@ +import numpy as np +from skimage.filters import threshold_otsu +from skimage.segmentation import expand_labels +from tqdm import tqdm + + +def extract_labels_from_channels( + nucleus_labels: np.array, + extra_channels: list, + radius: int = 4, + threshold_factor=2, + viewer=None, +): + """ + Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. + Args: + nucleus_labels (np.array): labels for the nuclei + extra_channels (list): channels arrays to extract labels from + radius: radius in which the approximation is made + + Returns: + A list of extracted labels for each extra channel + """ + labeled_channels = {} + + contrasted_channels = [] + for channel in extra_channels: + channel = (channel - np.min(channel)) / ( + np.max(channel) - np.min(channel) + ) + threshold_brightness = threshold_otsu(channel) * threshold_factor + channel_contrasted = np.where( + channel > threshold_brightness, channel, 0 + ) + contrasted_channels.append(channel_contrasted) + if viewer is not None: + viewer.add_image( + channel_contrasted, + name="channel_contrasted", + colormap="viridis", + ) + for label_id in tqdm(np.unique(nucleus_labels)): + if label_id == 0: + continue + label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + expanded = expand_labels(label_nucleus, distance=radius) + for i, channel in enumerate(contrasted_channels): + label_contrasted = np.where(expanded != 0, channel, 0) + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels[ + f"label_{label_id}_channel_{i+1}" + ] = np.count_nonzero(labeled_channel) + if np.count_nonzero(labeled_channel) > 0 and viewer is not None: + print(np.count_nonzero(labeled_channel)) + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) + + return labeled_channels + + +if __name__ == "__main__": + from pathlib import Path + + import napari + import pandas as pd + from tifffile import imread + + image_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + ) + # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + nuclei_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" + ) + extra_channels_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" + ) + extra_channels = [ + imread(str(path)) + for path in extra_channels_path.glob( + "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" + ) + ] + labels = imread(str(image_path)) + viewer = napari.Viewer() + + shift = 0 + viewer.add_image( + imread(str(nuclei_path))[ + shift : 32 + shift, shift : 32 + shift, shift : 32 + shift + ], + name="nuclei", + ) + viewer.add_labels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + [ + viewer.add_image( + channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + for channel in extra_channels + ] + + labeled_channels = extract_labels_from_channels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], + [ + c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + for c in extra_channels + ], + radius=4, + viewer=viewer, + ) + table = pd.DataFrame( + labeled_channels.items(), columns=["name", "pixels count"] + ) + print(table) + # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] + # expanded = expand_labels(labels, 4) + # viewer.add_labels(expanded) + napari.run() From c6e3c218ab4d841881265138d61e0896f0268635 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 09:18:42 +0200 Subject: [PATCH 248/577] Fixes - Fixed multi-channel instance and csv stats - Fixed rotation of inference outputs - Raised max crop size --- napari_cellseg3d/code_models/model_workers.py | 74 ++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- .../code_plugins/plugin_model_inference.py | 79 +++++++++---------- .../extract_extra_channels_labels.py | 64 +++++++++------ napari_cellseg3d/interface.py | 54 ++++++++----- napari_cellseg3d/utils.py | 6 ++ 6 files changed, 160 insertions(+), 119 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 9f38a534..9e0a5085 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -547,15 +547,15 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - extra_dims = len(semantic_labels.shape) - 3 + if semantic_labels is not None: - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) + semantic_labels = utils.correct_rotation(semantic_labels) if crf_results is not None: - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + crf_results = utils.correct_rotation(crf_results) + if instance_labels is not None: + instance_labels = utils.correct_rotation( + instance_labels + ) # TODO(cyril) check if correct return InferenceResult( image_id=i + 1, @@ -581,8 +581,6 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) - if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -608,10 +606,11 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"{additional_info}_Prediction_{i+1}" + + f"{additional_info}" + + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}_" + + f"_{time}" + filetype ) try: @@ -638,18 +637,20 @@ def aniso_transform(self, image): return image def instance_seg( - self, to_instance, image_id=0, original_filename="layer", channel=None + self, semantic_labels, image_id=0, original_filename="layer" ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) - if channel is not None: - channel_id = f"_{channel}" + if len(semantic_labels.shape) == 4: + instance_labels = np.array( + [method.run_method(ch) for ch in semantic_labels] + ) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") else: - channel_id = "" + instance_labels = method.run_method(image=semantic_labels) if self.config.filetype == "": filetype = "" @@ -661,7 +662,6 @@ def instance_seg( + "/" + f"Instance_seg_labels_{image_id}_" + original_filename - + channel_id + "_" + self.config.model_info.name + f"_{utils.get_date_time()}" @@ -720,7 +720,10 @@ def run_crf(self, image, labels, image_id=0): image, labels, config=self.config.crf_config, log=self.log ) self.save_image( - crf_results, i=image_id, additional_info="CRF", from_layer=True + crf_results, + i=image_id, + additional_info="CRF_", + from_layer=True, ) return crf_results except ValueError as e: @@ -728,14 +731,17 @@ def run_crf(self, image, labels, image_id=0): return None def stats_csv(self, instance_labels): - if self.config.compute_stats: - stats = volume_stats(instance_labels) - return stats - - # except ValueError as e: - # self.log(f"Error occurred during stats computing : {e}") - # return None - else: + try: + if self.config.compute_stats: + if len(instance_labels.shape) == 4: + stats = [volume_stats(c) for c in instance_labels] + else: + stats = [volume_stats(instance_labels)] + return stats + else: + return None + except ValueError as e: + self.log(f"Error occurred during stats computing : {e}") return None def inference_on_layer(self, image, model, post_process_transforms): @@ -753,15 +759,9 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels_results = [] - stats_results = [] - - for channel in out: - instance_labels, stats = self.get_instance_result( - channel, from_layer=True - ) - instance_labels_results.append(instance_labels) - stats_results.append(stats) + instance_labels, stats = self.get_instance_result( + semantic_labels=out, from_layer=True + ) if self.config.use_crf: crf_results = self.run_crf(image, out) @@ -770,10 +770,10 @@ def inference_on_layer(self, image, model, post_process_transforms): return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels_results, + instance_labels=instance_labels, crf_results=crf_results, from_layer=True, - stats=stats_results, + stats=stats, ) # @thread_worker(connect={"errored": self.raise_error}) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 6e7f91f3..323f8068 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -80,7 +80,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.results_filewidget.check_ready() self.crop_size_widgets = ui.IntIncrementCounter.make_n( - 3, 1, 1000, DEFAULT_CROP_SIZE + 3, 1, 10000, DEFAULT_CROP_SIZE ) self.crop_size_labels = [ ui.make_label("Size in " + axis + " of cropped volume :", self) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 157b8af7..9f093629 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -140,7 +140,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.thresholding_slider = ui.Slider( - lower=1, default=config.PostProcessConfig().thresholding.threshold_value * 100, divide_factor=100.0, @@ -437,10 +436,10 @@ def _build(self): self.anisotropy_wdgt, # anisotropy self.thresholding_checkbox, self.thresholding_slider.container, # thresholding - self.use_instance_choice, - self.instance_widgets, self.use_crf, self.crf_widgets, + self.use_instance_choice, + self.instance_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -754,61 +753,61 @@ def on_yield(self, result: InferenceResult): name=f"pred_{image_id}_{model_name}", opacity=0.8, ) + if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) if ( len(result.instance_labels) > 0 and self.worker_config.post_process_config.instance.enabled ): - for i, labels in enumerate(result.instance_labels): - # labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(result.instance_labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(result.instance_labels, name=name) from napari_cellseg3d.utils import LOGGER as log - log.debug(f"len stats : {len(result.stats)}") + if result.stats is not None and isinstance( + result.stats, list + ): + log.debug(f"len stats : {len(result.stats)}") - for i, stats in enumerate(result.stats): - # stats = result.stats + for i, stats in enumerate(result.stats): + # stats = result.stats - if ( - self.worker_config.compute_stats - and stats is not None - ): - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances in channel {i} : {stats.number_objects[0]}" - ) + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) - - if result.crf_results is not None: - logger.debug( - f"CRF results shape : {result.crf_results.shape}" - ) - viewer.add_image( - result.crf_results, - name=f"CRF_results_image_{image_id}", - colormap="viridis", - ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py index 2bd0a536..70ee10b6 100644 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -4,8 +4,8 @@ from tqdm import tqdm -def extract_labels_from_channels( - nucleus_labels: np.array, +def extract_labels_from_channels( # TODO add separate channels results + nuclei_labels: np.array, extra_channels: list, radius: int = 4, threshold_factor=2, @@ -14,15 +14,14 @@ def extract_labels_from_channels( """ Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. Args: - nucleus_labels (np.array): labels for the nuclei + nuclei_labels (np.array): labels for the nuclei extra_channels (list): channels arrays to extract labels from radius: radius in which the approximation is made Returns: A list of extracted labels for each extra channel """ - labeled_channels = {} - + labeled_channels = [] contrasted_channels = [] for channel in extra_channels: channel = (channel - np.min(channel)) / ( @@ -39,31 +38,54 @@ def extract_labels_from_channels( name="channel_contrasted", colormap="viridis", ) - for label_id in tqdm(np.unique(nucleus_labels)): + for label_id in tqdm(np.unique(nuclei_labels)): if label_id == 0: continue - label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) expanded = expand_labels(label_nucleus, distance=radius) + restricted = np.where(expanded != 0, nuclei_labels, 0) + overlap = np.where(restricted != label_id, restricted, 0) + for i, channel in enumerate(contrasted_channels): label_contrasted = np.where(expanded != 0, channel, 0) - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels[ - f"label_{label_id}_channel_{i+1}" - ] = np.count_nonzero(labeled_channel) - if np.count_nonzero(labeled_channel) > 0 and viewer is not None: - print(np.count_nonzero(labeled_channel)) - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) + if overlap.any() != 0: + max_labeled = 0 + for overlap_id in np.unique(overlap): + if overlap_id == 0: + continue + assigned_pixels = np.count_nonzero( + np.where(overlap == overlap_id, channel, 0) + ) + if assigned_pixels > max_labeled: + max_labeled = assigned_pixels + max_label_id = overlap_id + if label_id != max_label_id: + labeled_channels.append( + np.zeros_like(label_contrasted) + ) + else: + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels.append(labeled_channel) + if ( + np.count_nonzero(labeled_channel) > 0 + and viewer is not None + ): + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) - return labeled_channels + cat_labels = np.zeros_like(nuclei_labels) + for labels in np.unique(labeled_channels): + if labels == 0: + continue + cat_labels += np.where(labels != 0, labels, 0) + return cat_labels if __name__ == "__main__": from pathlib import Path import napari - import pandas as pd from tifffile import imread image_path = ( @@ -114,10 +136,8 @@ def extract_labels_from_channels( radius=4, viewer=viewer, ) - table = pd.DataFrame( - labeled_channels.items(), columns=["name", "pixels count"] - ) - print(table) + + viewer.add_labels(labeled_channels) # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] # expanded = expand_labels(labels, 4) # viewer.add_labels(expanded) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 55e5abb3..9d06863e 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -469,6 +469,11 @@ def __init__( ): super().__init__(orientation, parent) + if upper <= lower: + raise ValueError( + "The minimum value cannot be below the maximum one" + ) + self.setMaximum(upper) self.setMinimum(lower) self.setSingleStep(step) @@ -544,23 +549,29 @@ def _warn_outside_bounds(self, default): def _update_slider(self): """Update slider when value is changed""" - if self._value_label.text() == "": - return + try: + if self._value_label.text() == "": + return - value = float(self._value_label.text()) * self._divide_factor + value = float(self._value_label.text()) * self._divide_factor - if value < self.minimum(): - self.slider_value = self.minimum() - return - if value > self.maximum(): - self.slider_value = self.maximum() - return + if value < self.minimum(): + self.slider_value = self.minimum() + return + if value > self.maximum(): + self.slider_value = self.maximum() + return - self.slider_value = value + self.slider_value = value + except Exception as e: + logger.error(e) def _update_value_label(self): """Update label, to connect to when slider is dragged""" - self._value_label.setText(str(self.value_text)) + try: + self._value_label.setText(str(self.value_text)) + except Exception as e: + logger.error(e) @property def tooltips(self): @@ -596,16 +607,21 @@ def value_text(self): def slider_value(self, value: int): """Set a value (int) divided by self._divide_factor""" if value < self.minimum() or value > self.maximum(): - raise ValueError( - f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + logger.error( + ValueError( + f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + ) ) - self.setValue(int(value)) - - divided = value / self._divide_factor - if self._divide_factor == 1.0: - divided = int(divided) - self._value_label.setText(str(divided)) + try: + self.setValue(int(value)) + + divided = value / self._divide_factor + if self._divide_factor == 1.0: + divided = int(divided) + self._value_label.setText(str(divided)) + except Exception as e: + logger.error(e) class AnisotropyWidgets(QWidget): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 1aa316d2..75c9734e 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -202,6 +202,12 @@ def dice_coeff(y_true, y_pred): return score +def correct_rotation(image): + """Rotates the exes 0 and 2 in [DHW] section of image array""" + extra_dims = len(image) - 3 + return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, From 88bc8240b31eddb5a4cdb7027317fa151753f6d6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 14:42:02 +0200 Subject: [PATCH 249/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 9f093629..df64a625 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -762,9 +762,8 @@ def on_yield(self, result: InferenceResult): name=f"CRF_results_image_{image_id}", colormap="viridis", ) - if ( - len(result.instance_labels) > 0 + result.instance_labels is not None and self.worker_config.post_process_config.instance.enabled ): method_name = ( From 2b4ae10f9ce17831b637d7f6b97e537e9e898a34 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 6 May 2023 09:56:17 +0200 Subject: [PATCH 250/577] Update plugin_crop.py --- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 323f8068..46c2cfb2 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -49,7 +49,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.label_layer_loader.layer_list.label.setText("Image 2") self.crop_second_image_choice = ui.CheckBox( - "Crop another\nimage simultaneously", + "Crop another\nimage/label simultaneously", ) self.crop_second_image_choice.toggled.connect( self._toggle_second_image_io_visibility From ba67da4b3268540679b61a8c71b618859b2f89b8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 10:16:58 +0200 Subject: [PATCH 251/577] Fixed patch_func sample number mismatch --- napari_cellseg3d/code_models/model_workers.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 9e0a5085..88d374a0 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1230,30 +1230,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if do_sampling: - sample_loader = Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=self.config.num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), - ), - EnsureTyped(keys=["image", "label"]), - ] - ) if self.config.do_augmentation: train_transforms = ( @@ -1285,6 +1261,31 @@ def train(self): ] ) # self.log("Loading dataset...\n") + def get_loader_func(num_samples): + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) + ), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + if do_sampling: # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior @@ -1297,11 +1298,17 @@ def train(self): self.config.num_samples * (1 - self.config.validation_percent) ) + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) else: num_train_samples = ( num_val_samples ) = self.config.num_samples + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) + + logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" @@ -1311,14 +1318,14 @@ def train(self): train_ds = PatchDataset( data=self.train_files, transform=train_transforms, - patch_func=sample_loader, + patch_func=sample_loader_train, samples_per_image=num_train_samples, ) logger.debug("val_ds") val_ds = PatchDataset( data=self.val_files, transform=val_transforms, - patch_func=sample_loader, + patch_func=sample_loader_eval, samples_per_image=num_val_samples, ) From e63bb588913b593edd393c881ed936982ab8ed86 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:08:52 +0200 Subject: [PATCH 252/577] Testing relabel tools --- napari_cellseg3d/dev_scripts/correct_labels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2ab60332..9862c3fa 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -367,8 +367,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") +# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") # -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# image_path = str(im_path / "volumes/images.tif") +# gt_labels_path = str(im_path / "labels/testing_im.tif") +# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) From 0883fe90c502e958aab53b09f396d9e89e21b1fc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:38:45 +0200 Subject: [PATCH 253/577] Fixes in inference --- napari_cellseg3d/code_models/model_workers.py | 2 ++ napari_cellseg3d/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 88d374a0..754a5007 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -504,6 +504,8 @@ def model_output_wrapper(inputs): sw_device=self.config.device, device=dataset_device, overlap=window_overlap, + mode="gaussian", + sigma_scale=0.01, progress=True, ) except Exception as e: diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 75c9734e..86754ad0 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -204,7 +204,7 @@ def dice_coeff(y_true, y_pred): def correct_rotation(image): """Rotates the exes 0 and 2 in [DHW] section of image array""" - extra_dims = len(image) - 3 + extra_dims = len(image.shape) - 3 return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) From 8546f1c564100ac5943c81a6cee73b52bf8ecf9b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 14:48:14 +0200 Subject: [PATCH 254/577] add model template + fix test + wnet loading opti - test fixes - changed crf input reqs - adapted instance seg for several channels --- napari_cellseg3d/_tests/test_models.py | 10 ++- .../_tests/test_plugin_inference.py | 11 ++-- napari_cellseg3d/_tests/test_training.py | 11 ++-- napari_cellseg3d/code_models/crf.py | 11 ++-- .../code_models/model_instance_seg.py | 29 ++++++++- napari_cellseg3d/code_models/model_workers.py | 62 +++++++++---------- .../code_models/models/TEMPLATE_model.py | 20 ++++++ .../code_models/models/model_SwinUNetR.py | 13 +++- .../code_models/models/model_WNet.py | 19 ++++++ .../code_plugins/plugin_convert.py | 2 +- 10 files changed, 129 insertions(+), 59 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/TEMPLATE_model.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 1fc15872..35af8c76 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -15,6 +15,8 @@ def test_correct_shape_for_crf(): def test_model_list(): for model_name in MODEL_LIST.keys(): + # if model_name=="test": + # continue dims = 128 test = MODEL_LIST[model_name]( input_img_size=[dims, dims, dims], @@ -39,18 +41,20 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) - # assert res > 0 + assert 0 <= res <= 1 def test_crf(qtbot): dims = 8 mock_image = np.random.rand(1, dims, dims, dims) mock_label = np.random.rand(2, dims, dims, dims) - - crf = CRFWorker(mock_image, mock_label) + assert len(mock_label.shape) == 4 + crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) + assert len(result.shape) == 4 + assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] crf.yielded.connect(on_yield) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 66c50fba..3dafeabc 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,9 +3,10 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_inference(make_napari_viewer, qtbot): @@ -28,9 +29,9 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.setCurrentIndex(-1) # widget.start() # takes too long on Github Actions # assert widget.worker is not None diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 21731ba1..921a6d26 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,9 +2,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_training(make_napari_viewer, qtbot): @@ -32,9 +33,9 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) # widget.start() # assert widget.worker is not None diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 1b8dce28..21caf35f 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -57,10 +57,10 @@ def correct_shape_for_crf(image, desired_dims=4): if len(image.shape) == desired_dims: return image if len(image.shape) > desired_dims: - if image.shape[0] > 1: - raise ValueError( - f"Image shape {image.shape} might have several channels" - ) + # if image.shape[0] > 1: + # raise ValueError( + # f"Image shape {image.shape} might have several channels" + # ) image = np.squeeze(image, axis=0) if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) @@ -200,7 +200,6 @@ def __init__( self.config = config self.log = log - # TODO(cyril) : add progress bar into log ? or do it in inference def _run_crf_job(self): """Runs the CRF post-processing step for the W-Net.""" if not CRF_INSTALLED: @@ -211,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + # labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d1a03eec..0c3c6c6b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from functools import partial from typing import List @@ -82,8 +83,32 @@ def __init__( ) self.counters.append(getattr(self, widget)) + @abc.abstractmethod def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") + raise NotImplementedError() + + def _make_list_from_channels( + self, image + ): # TODO(cyril) : adapt to batch dimension + if len(image.shape) > 4: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" + ) + if len(image.shape) == 4: + image = np.squeeze(image) + if len(image.shape) == 4: + return [im for im in image] + elif len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) + else: + return [image] + + def run_method_on_channels(self, image): + image_list = self._make_list_from_channels(image) # FIXME rename + result = np.array([self.run_method(im) for im in image_list]) + return result.squeeze() class InstanceMethod: @@ -611,7 +636,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method(volume) + return method.run_method_on_channels(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 754a5007..93f9908b 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -645,17 +645,11 @@ def instance_seg( self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - - if len(semantic_labels.shape) == 4: - instance_labels = np.array( - [method.run_method(ch) for ch in semantic_labels] - ) - self.log(f"DEBUG instance results shape : {instance_labels.shape}") - else: - instance_labels = method.run_method(image=semantic_labels) + instance_labels = method.run_method_on_channels(semantic_labels) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") if self.config.filetype == "": - filetype = "" + filetype = ".tif" else: filetype = "_" + self.config.filetype @@ -855,7 +849,8 @@ def inference(self): weights = str( PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - model.load_state_dict( + + model.load_state_dict( # note that this is redefined in WNet_ torch.load( weights, map_location=self.config.device, @@ -1232,7 +1227,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? @@ -1262,31 +1256,32 @@ def train(self): EnsureTyped(keys=["image", "label"]), ] ) + # self.log("Loading dataset...\n") def get_loader_func(num_samples): - return Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) ), - EnsureTyped(keys=["image", "label"]), - ] - ) + ), + EnsureTyped(keys=["image", "label"]), + ] + ) if do_sampling: # if there is only one volume, split samples @@ -1310,7 +1305,6 @@ def get_loader_func(num_samples): sample_loader_train = get_loader_func(num_train_samples) sample_loader_eval = get_loader_func(num_val_samples) - logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py new file mode 100644 index 00000000..f68e5f4f --- /dev/null +++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ModelTemplate_(ABC): + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 05819e22..484890d1 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -9,12 +9,19 @@ class SwinUNETR_(SwinUNETR): use_default_training = True weights_file = "Swin64_best_metric.pth" - def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + def __init__( + self, + in_channels=1, + out_channels=1, + input_img_size=128, + use_checkpoint=True, + **kwargs, + ): try: super().__init__( input_img_size, - in_channels=1, - out_channels=1, + in_channels=in_channels, + out_channels=out_channels, feature_size=48, use_checkpoint=use_checkpoint, **kwargs, diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 4a9ff70d..86a1f7e6 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,5 +1,12 @@ +from typing import TypeVar + +from torch.nn import Module + +# local from napari_cellseg3d.code_models.models.wnet.model import WNet +T = TypeVar("T", bound="Module") + class WNet_(WNet): use_default_training = False @@ -20,6 +27,9 @@ def __init__( num_classes=num_classes, ) + def train(self: T, mode: bool = True) -> T: + raise NotImplementedError("Training not implemented for WNet") + def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. @@ -27,3 +37,12 @@ def forward(self, x): enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc + + def load_state_dict(self, state_dict, strict=False): + """Load the model state dict for inference, without the decoder weights.""" + encoder_checkpoint = state_dict.copy() + for k in state_dict.keys(): + if k.startswith("decoder"): + encoder_checkpoint.pop(k) + # print(encoder_checkpoint.keys()) + super().load_state_dict(encoder_checkpoint, strict=strict) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f7b476d0..8353632e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -363,7 +363,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - self.instance_widgets.run_method(imread(file)) + self.instance_widgets.run_method_on_channels(imread(file)) for file in self.images_filepaths ] utils.save_folder( From 40f1c79f17c54d2c6b542c21452fdbf72a5c8db0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 15:16:25 +0200 Subject: [PATCH 255/577] Update model_WNet.py --- napari_cellseg3d/code_models/models/model_WNet.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 86a1f7e6..f07ac517 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,12 +1,6 @@ -from typing import TypeVar - -from torch.nn import Module - # local from napari_cellseg3d.code_models.models.wnet.model import WNet -T = TypeVar("T", bound="Module") - class WNet_(WNet): use_default_training = False @@ -27,8 +21,8 @@ def __init__( num_classes=num_classes, ) - def train(self: T, mode: bool = True) -> T: - raise NotImplementedError("Training not implemented for WNet") + # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): """Forward ENCODER pass of the W-Net model. From 6a4896dd4379ff2ccb258795af617caa41a4b365 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 13 May 2023 10:29:39 +0200 Subject: [PATCH 256/577] Update model_VNet.py --- napari_cellseg3d/code_models/models/model_VNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 7aa6476e..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=2, **kwargs): + def __init__(self, in_channels=1, out_channels=1, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs From 8f63ec2ceae05f083c67e32262138eaa26028b8e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:51:02 +0200 Subject: [PATCH 257/577] Fixed folder creation when saving to folder --- napari_cellseg3d/code_models/crf.py | 2 +- napari_cellseg3d/code_plugins/plugin_convert.py | 10 +++++----- napari_cellseg3d/code_plugins/plugin_crf.py | 2 +- napari_cellseg3d/utils.py | 3 +++ 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 21caf35f..aa9cce75 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -210,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - # labels = correct_shape_for_crf(labels) + labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 8353632e..77aa9af6 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -46,7 +46,7 @@ def __init__(self, viewer: "napari.Viewer.viewer", parent=None): self.aniso_widgets = ui.AnisotropyWidgets(self, always_visible=True) self.start_btn = ui.Button("Start", self._start) - self.results_path = Path.home() / Path("cellseg3d/anisotropy") + self.results_path = str(Path.home() / Path("cellseg3d/anisotropy")) self.results_filewidget.text_field.setText(str(self.results_path)) self.results_filewidget.check_ready() @@ -76,7 +76,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + utils.mkdir_from_str(self.results_path) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -175,7 +175,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -342,7 +342,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -436,7 +436,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.binarize_counter.value() if self.layer_choice: diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 7ac605e9..d8407a0f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -238,7 +238,7 @@ def _start(self): self.result_layer = self.label_layer_loader.layer() self.result_name = self.label_layer_loader.layer_name() - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) image_list = [self.image_layer_loader.layer_data()] labels_list = [self.label_layer_loader.layer_data()] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 86754ad0..6e2f7341 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -131,6 +131,9 @@ def normalize_x(image): image = image / 127.5 - 1 return image +def mkdir_from_str(path: str, exist_ok=True, parents=True): + Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) + def normalize_y(image): """Normalizes the values of an image array to be between [0;1] rather than [0;255] From c0c3d436ee979c5730de1acbc0dfb4ddc4740e67 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:54:07 +0200 Subject: [PATCH 258/577] Fix check_ready for results filewidget --- napari_cellseg3d/interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 9d06863e..6c5eb5c3 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -854,6 +854,9 @@ def __init__( self.build() self.check_ready() + if self._required: + self._text_field.textChanged.connect(self.check_ready) + def build(self): """Builds the layout of the widget""" add_widgets( @@ -912,7 +915,7 @@ def required(self, is_required): try: self.text_field.textChanged.disconnect(self.check_ready) except TypeError: - return + pass self.check_ready() self._required = is_required From bd6ebe33d90decb478637d9dad5be349e977ccce Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 11:28:33 +0200 Subject: [PATCH 259/577] Added remapping in WNet + ruff config --- .pre-commit-config.yaml | 3 ++ napari_cellseg3d/code_models/model_workers.py | 51 ++++++++----------- napari_cellseg3d/utils.py | 48 ++++++++++++----- pyproject.toml | 36 ++++++++++++- 4 files changed, 93 insertions(+), 45 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61ecaae5..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,9 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + - id: check-toml # - repo: https://github.com/pycqa/isort # rev: 5.12.0 # hooks: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 93f9908b..4ce4d180 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -119,7 +119,7 @@ def show_progress(_, block_size, __): # count, block_size, total_size logger.info(message) return - with open(json_path) as f: + with Path.open(json_path) as f: neturls = json.load(f) if model_name in neturls: url = neturls[model_name] @@ -259,8 +259,7 @@ def create_inference_dict(images_filepaths): Returns: dict: list of image paths from loaded folder""" - data_dicts = [{"image": image_name} for image_name in images_filepaths] - return data_dicts + return [{"image": image_name} for image_name in images_filepaths] def set_download_log(self, widget): self.downloader.log_widget = widget @@ -472,10 +471,9 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) if self.config.sliding_window_config.is_enabled(): window_size = self.config.sliding_window_config.window_size @@ -492,6 +490,7 @@ def model_output( # outputs = model(inputs) def model_output_wrapper(inputs): + inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) @@ -509,7 +508,7 @@ def model_output_wrapper(inputs): progress=True, ) except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.debug("failed to run sliding window inference") self.raise_error(e, "Error during sliding window inference") logger.debug(f"Inference output shape: {outputs.shape}") @@ -520,11 +519,9 @@ def model_output_wrapper(inputs): if post_process: out = np.array(out).astype(np.float32) out = np.squeeze(out) - return out - else: - return out + return out except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Error during sliding window inference") # sys.stdout = old_stdout # sys.stderr = old_stderr @@ -635,8 +632,7 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - else: - return image + return image def instance_seg( self, semantic_labels, image_id=0, original_filename="layer" @@ -648,10 +644,11 @@ def instance_seg( instance_labels = method.run_method_on_channels(semantic_labels) self.log(f"DEBUG instance results shape : {instance_labels.shape}") - if self.config.filetype == "": - filetype = ".tif" - else: - filetype = "_" + self.config.filetype + filetype = ( + ".tif" + if self.config.filetype == "" + else "_" + self.config.filetype + ) instance_filepath = ( self.config.results_path @@ -733,9 +730,9 @@ def stats_csv(self, instance_labels): stats = [volume_stats(c) for c in instance_labels] else: stats = [volume_stats(instance_labels)] - return stats else: - return None + stats = None + return stats except ValueError as e: self.log(f"Error occurred during stats computing : {e}") return None @@ -759,10 +756,7 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - if self.config.use_crf: - crf_results = self.run_crf(image, out) - else: - crf_results = None + crf_results = self.run_crf(image, out) if self.config.use_crf else None return self.create_inference_result( semantic_labels=out, @@ -944,7 +938,7 @@ def inference(self): model.to("cpu") # self.quit() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Inference failed") self.quit() finally: @@ -1175,10 +1169,7 @@ def train(self): do_sampling = self.config.sampling - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check model = model_class( # FIXME check if correct input_img_size=utils.get_padding_dim(size), use_checkpoint=True @@ -1411,7 +1402,7 @@ def get_loader_func(num_samples): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") - logger.error(e, exc_info=True) + logger.exception(e) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 6e2f7341..7ca29e00 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Union import napari import numpy as np @@ -9,6 +10,9 @@ from skimage.filters import gaussian from tifffile import imread, imwrite +if TYPE_CHECKING: + import torch + LOGGER = logging.getLogger(__name__) ############### # Global logging level setting @@ -128,8 +132,8 @@ def normalize_x(image): Returns: array: normalized value for the image """ - image = image / 127.5 - 1 - return image + return image / 127.5 - 1 + def mkdir_from_str(path: str, exist_ok=True, parents=True): Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) @@ -144,8 +148,7 @@ def normalize_y(image): Returns: array: normalized value for the image """ - image = image / 255 - return image + return image / 255 def sphericity_volume_area(volume, surface_area): @@ -199,10 +202,9 @@ def dice_coeff(y_true, y_pred): y_true_f = y_true.flatten() y_pred_f = y_pred.flatten() intersection = np.sum(y_true_f * y_pred_f) - score = (2.0 * intersection + smooth) / ( + return (2.0 * intersection + smooth) / ( np.sum(y_true_f) + np.sum(y_pred_f) + smooth ) - return score def correct_rotation(image): @@ -211,6 +213,27 @@ def correct_rotation(image): return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) +def normalize_max(image): + """Normalizes an image using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image.reshape(shape) + return image + + +def remap_image( + image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + image = image.reshape(shape) + return image + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, @@ -276,10 +299,11 @@ def time_difference(time_start, time_finish, as_string=True): minutes = f"{int(minutes[0])}".zfill(2) seconds = f"{int(seconds[0])}".zfill(2) - if as_string: - return f"{hours}:{minutes}:{seconds}" - else: - return [hours, minutes, seconds] + return ( + f"{hours}:{minutes}:{seconds}" + if as_string + else [hours, minutes, seconds] + ) def get_padding_dim(image_shape, anisotropy_factor=None): @@ -549,10 +573,8 @@ def load_images( "Loading as folder not implemented yet. Use napari to load as folder" ) # images_original = dask_imread(filename_pattern_original) - else: - images_original = imread(filename_pattern_original) # tifffile imread - return images_original + return imread(filename_pattern_original) # tifffile imread # def load_predicted_masks(mito_mask_dir, er_mask_dir, filetype): diff --git a/pyproject.toml b/pyproject.toml index d223072a..81d2a788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,11 +46,43 @@ where = ["."] [tool.ruff] select = [ "E", "F", "W", - "I", + "A", "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -ignore = ["E501", "E741"] +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] [tool.black] line-length = 79 From 529598fe5a176be75c69b35ab1369f6b28252e1d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 13:21:06 +0200 Subject: [PATCH 260/577] Run new hooks --- napari_cellseg3d/_tests/test_models.py | 13 +- .../_tests/test_weight_download.py | 2 +- napari_cellseg3d/code_models/crf.py | 25 ++-- ...stance_seg.py => instance_segmentation.py} | 19 ++- .../code_models/model_framework.py | 11 +- .../code_models/models/model_SwinUNetR.py | 2 +- .../code_models/models/model_TRAILMAP_MS.py | 2 +- .../code_models/models/model_WNet.py | 8 +- .../code_models/models/unet/buildingblocks.py | 3 +- .../code_models/models/wnet/soft_Ncuts.py | 4 +- .../{model_workers.py => workers.py} | 2 +- .../code_plugins/plugin_convert.py | 127 +++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 6 +- .../code_plugins/plugin_metrics.py | 12 +- .../code_plugins/plugin_model_inference.py | 11 +- .../code_plugins/plugin_model_training.py | 18 +-- .../code_plugins/plugin_review.py | 11 +- .../code_plugins/plugin_review_dock.py | 5 +- napari_cellseg3d/config.py | 8 +- .../dev_scripts/artefact_labeling.py | 16 +-- .../dev_scripts/correct_labels.py | 7 +- napari_cellseg3d/interface.py | 58 ++++---- pyproject.toml | 2 + 23 files changed, 191 insertions(+), 181 deletions(-) rename napari_cellseg3d/code_models/{model_instance_seg.py => instance_segmentation.py} (99%) rename napari_cellseg3d/code_models/{model_workers.py => workers.py} (99%) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35af8c76..35174b85 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,20 +1,23 @@ import numpy as np import torch +from numpy.random import PCG64, Generator from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +rand_gen = Generator(PCG64(12345)) + def test_correct_shape_for_crf(): - test = np.random.rand(1, 1, 8, 8, 8) + test = rand_gen.random(size=(1, 1, 8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) - test = np.random.rand(8, 8, 8) + test = rand_gen.random(size=(8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) def test_model_list(): - for model_name in MODEL_LIST.keys(): + for model_name in MODEL_LIST: # if model_name=="test": # continue dims = 128 @@ -46,8 +49,8 @@ def test_soft_ncuts_loss(): def test_crf(qtbot): dims = 8 - mock_image = np.random.rand(1, dims, dims, dims) - mock_label = np.random.rand(2, dims, dims, dims) + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index b9d4abe5..be694d99 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( PRETRAINED_WEIGHTS_DIR, WeightsDownloader, ) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index aa9cce75..8c311059 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,17 +54,15 @@ def correct_shape_for_crf(image, desired_dims=4): - if len(image.shape) == desired_dims: - return image if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( # f"Image shape {image.shape} might have several channels" # ) image = np.squeeze(image, axis=0) - if len(image.shape) < desired_dims: + elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) - return correct_shape_for_crf(image) + return image def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): @@ -185,8 +183,8 @@ class CRFWorker(GeneratorWorker): def __init__( self, - images_list, - labels_list, + images_list: list, + labels_list: list, config: CRFConfig = None, log=None, ): @@ -205,16 +203,19 @@ def _run_crf_job(self): if not CRF_INSTALLED: raise ImportError("pydensecrf is not installed.") - for image, labels in zip(self.images, self.labels): - if image.shape[-3:] != labels.shape[-3:]: + if len(self.images) != len(self.labels): + raise ValueError("Number of images and labels must be the same.") + + for i in range(len(self.images)): + if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + im = correct_shape_for_crf(self.labels[i]) + prob = correct_shape_for_crf(self.labels[i]) yield crf( - image, - labels, + im, + prob, self.config.sa, self.config.sb, self.config.sg, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/instance_segmentation.py similarity index 99% rename from napari_cellseg3d/code_models/model_instance_seg.py rename to napari_cellseg3d/code_models/instance_segmentation.py index 0c3c6c6b..5de7ab0c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -94,16 +94,16 @@ def _make_list_from_channels( raise ValueError( f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" ) + if len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) if len(image.shape) == 4: image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - elif len(image.shape) < 2: - raise ValueError( - f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" - ) - else: return [image] + return None def run_method_on_channels(self, image): image_list = self._make_list_from_channels(image) # FIXME rename @@ -353,12 +353,10 @@ def to_instance(image, is_file_path=False): image = [imread(image)] # image = image.compute() - result = binary_watershed( + return binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 ) # FIXME add params from utils plugin - return result - def to_semantic(image, is_file_path=False): """Converts a **ground-truth** label to semantic (binary 0/1) labels. @@ -375,8 +373,7 @@ def to_semantic(image, is_file_path=False): # image = image.compute() image[image >= 1] = 1 - result = image.astype(np.uint16) - return result + return image.astype(np.uint16) def volume_stats(volume_image): @@ -620,7 +617,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 37fc6a49..60644916 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,8 +1,11 @@ from pathlib import Path +from typing import TYPE_CHECKING -import napari import torch +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QProgressBar, QSizePolicy @@ -126,7 +129,7 @@ def save_log(self): path = self.results_path if len(log) != 0: - with open( + with Path.open( path + f"/Log_report_{utils.get_date_time()}.txt", "x", ) as f: @@ -152,8 +155,8 @@ def save_log_to_path(self, path): ) if len(log) != 0: - with open( - path, + with Path.open( + Path(path), "x", ) as f: f.write(log) diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 484890d1..2d7b5ef6 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -27,7 +27,7 @@ def __init__( **kwargs, ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( input_img_size, in_channels=1, diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 1123173a..baf8635d 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -16,7 +16,7 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs): in_channels=in_channels, out_channels=out_channels, **kwargs ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( in_channels=in_channels, out_channels=out_channels ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index f07ac517..7235bd61 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -28,14 +28,14 @@ def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. """ - enc = self.forward_encoder(x) - # dec = self.forward_decoder(enc) - return enc + return self.forward_encoder(x) + # enc = self.forward_encoder(x) + # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" encoder_checkpoint = state_dict.copy() - for k in state_dict.keys(): + for k in state_dict: if k.startswith("decoder"): encoder_checkpoint.pop(k) # print(encoder_checkpoint.keys()) diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 73913ab8..ce7d378f 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -422,8 +422,7 @@ def forward(self, encoder_features, x): def _joining(encoder_features, x, concat): if concat: return torch.cat((encoder_features, x), dim=1) - else: - return encoder_features + x + return encoder_features + x def create_encoders( diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 4e84579f..938292c2 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -206,6 +206,7 @@ def forward(self, labels, inputs): return torch.add(torch.neg(loss), K) """ + return None def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -348,5 +349,4 @@ def get_weights(self, inputs): 1, 1, self.W_X.shape[0], self.W_X.shape[1] ) # (1, 1, H*W*D, H*W*D) - W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) - return W + return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/workers.py similarity index 99% rename from napari_cellseg3d/code_models/model_workers.py rename to napari_cellseg3d/code_models/workers.py index 4ce4d180..c1ed62fd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -54,7 +54,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( ImageStats, volume_stats, ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 77aa9af6..4357e51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -7,7 +7,7 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, clear_small_objects, threshold, @@ -98,18 +98,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -193,18 +194,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -274,18 +276,19 @@ def _start(self): utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -360,18 +363,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method_on_channels(imread(file)) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method_on_channels(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -454,18 +458,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 46c2cfb2..a27b4baa 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -157,8 +157,10 @@ def _build(self): dim_group_l.addWidget(self.aniso_widgets) [ dim_group_l.addWidget(widget, alignment=ui.ABS_AL) - for list in zip(self.crop_size_labels, self.crop_size_widgets) - for widget in list + for widget_list in zip( + self.crop_size_labels, self.crop_size_widgets + ) + for widget in widget_list ] dim_group_w.setLayout(dim_group_l) layout.addWidget(dim_group_w) diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 114025f6..2a6e713c 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt -import napari import numpy as np from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, @@ -8,9 +9,12 @@ from monai.transforms import SpatialPad, ToTensor from tifffile import imread +if TYPE_CHECKING: + import napari + from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.instance_segmentation import to_semantic from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder DEFAULT_THRESHOLD = 0.5 @@ -187,11 +191,11 @@ def compute_dice(self): self.canvas = ( None # kind of terrible way to stack plots... but it works. ) - id = 0 + image_id = 0 for ground_path, pred_path in zip( self.images_filepaths, self.labels_filepaths ): - id += 1 + image_id += 1 ground = imread(ground_path) pred = imread(pred_path) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index df64a625..bb46617d 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,18 +1,21 @@ from functools import partial +from typing import TYPE_CHECKING -import napari import numpy as np import pandas as pd +if TYPE_CHECKING: + import napari + # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.workers import ( InferenceResult, InferenceWorker, ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 86d1d317..35a16799 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,9 +1,9 @@ import shutil from functools import partial from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt -import napari import numpy as np import pandas as pd import torch @@ -12,6 +12,9 @@ ) from matplotlib.figure import Figure +if TYPE_CHECKING: + import napari + # MONAI from monai.losses import ( DiceCELoss, @@ -29,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( TrainingReport, TrainingWorker, ) @@ -414,11 +417,10 @@ def check_ready(self): * False and displays a warning if not """ - if self.images_filepaths != [] and self.labels_filepaths != []: - return True - else: + if self.images_filepaths == [] and self.labels_filepaths != []: logger.warning("Image and label paths are not correctly set") return False + return True def _build(self): """Builds the layout of the widget and creates the following tabs and prompts: @@ -999,7 +1001,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1131,7 +1133,7 @@ def update_loss_plot(self, loss, metric): epoch = len(loss) if epoch < self.worker_config.validation_interval * 2: return - elif epoch == self.worker_config.validation_interval * 2: + if epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): self.canvas = FigureCanvas(Figure(figsize=(10, 1.5))) @@ -1167,7 +1169,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 235595e4..dd98bcd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -178,11 +178,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - logger.warning( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -400,7 +399,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 8753a642..f634d117 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,9 +1,12 @@ from datetime import datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING -import napari import pandas as pd +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QVBoxLayout, QWidget diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 8a7c1565..5c0b34be 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,7 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ @@ -89,9 +89,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :") + for model_name in MODEL_LIST: + logger.info(f" * {model_name}") return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b4712aec..93746eb6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ -import os +import os # TODO(cyril): remove os +from pathlib import Path import napari import numpy as np @@ -6,7 +7,7 @@ from skimage.filters import threshold_otsu from tifffile import imread, imwrite -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -289,18 +290,13 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. """ - if not is_labeled: - # find all the connected components in the artefacts image - labels = ndimage.label(artefacts)[0] - else: - labels = artefacts + labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - artefacts = np.where(np.isin(labels, labels_i), labels, 0) - return artefacts + return np.where(np.isin(labels, labels_i), labels, 0) def create_artefact_labels( @@ -388,7 +384,7 @@ def create_artefact_labels_from_folder( path_labels.sort() path_images.sort() # create the output folder - os.makedirs(path + "/artefact_neurons", exist_ok=True) + Path().mkdir(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9862c3fa..4a7363b2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -12,7 +12,7 @@ from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) @@ -228,10 +228,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 6c5eb5c3..df00ad0b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,3 +1,4 @@ +import contextlib import threading from functools import partial from typing import List, Optional @@ -104,12 +105,12 @@ def __call__(cls, *args, **kwargs): ################## -def handle_adjust_errors(widget, type, context, msg: str): +def handle_adjust_errors(widget, warning_type, context, msg: str): """Qt message handler that attempts to react to errors when setting the window size and resizes the main window""" pass # head = msg.split(": ")[0] - # if type == QtWarningMsg and head == "QWindowsWindow::setGeometry": + # if warning_type == QtWarningMsg and head == "QWindowsWindow::setGeometry": # logger.warning( # f"Qt resize error : {msg}\nhas been handled by attempting to resize the window" # ) @@ -332,8 +333,7 @@ def toggle_visibility(checkbox, widget): def add_label(widget, label, label_before=True, horizontal=True): if label_before: return combine_blocks(widget, label, horizontal=horizontal) - else: - return combine_blocks(label, widget, horizontal=horizontal) + return combine_blocks(label, widget, horizontal=horizontal) class ContainerWidget(QWidget): @@ -735,8 +735,7 @@ def anisotropy_zoom_factor(aniso_res): """ base = min(aniso_res) - zoom_factors = [base / res for res in aniso_res] - return zoom_factors + return [base / res for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" @@ -796,8 +795,8 @@ def _remove_layer(self, event): index = self.layer_list.findText(removed_layer.name) self.layer_list.removeItem(index) - def set_layer_type(self, type): # no @property due to Qt constraint - self.layer_type = type + def set_layer_type(self, layer_type): # no @property due to Qt constraint + self.layer_type = layer_type [self.layer_list.removeItem(i) for i in range(self.layer_list.count())] self._check_for_layers() @@ -810,7 +809,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: logger.warning("Please select a valid layer !") - return + return None return self.layer().data @@ -898,9 +897,8 @@ def check_ready(self): self.update_field_color("indianred") self.text_field.setToolTip("Mandatory field !") return False - else: - self.update_field_color(f"{napari_param_darkgrey}") - return True + self.update_field_color(f"{napari_param_darkgrey}") + return True @property def required(self): @@ -912,10 +910,9 @@ def required(self, is_required): if is_required: self.text_field.textChanged.connect(self.check_ready) else: - try: + with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - except TypeError: - pass + self.check_ready() self._required = is_required @@ -1002,22 +999,22 @@ def make_scrollable( def set_spinbox( box, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, fixed: Optional[bool] = True, ): """Args: box : QSpinBox or QDoubleSpinBox - min : minimum value, defaults to 0 - max : maximum value, defaults to 10 + min_value : minimum value, defaults to 0 + max_value : maximum value, defaults to 10 default : default value, defaults to 0 step : step value, defaults to 1 fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" - box.setMinimum(min) - box.setMaximum(max) + box.setMinimum(min_value) + box.setMaximum(max_value) box.setSingleStep(step) box.setValue(default) @@ -1028,8 +1025,8 @@ def set_spinbox( def make_n_spinboxes( class_, n: int = 2, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, parent: Optional[QWidget] = None, @@ -1040,8 +1037,8 @@ def make_n_spinboxes( Args: class_ : QSpinBox or QDoubleSpinbox n (int): number of increment counters to create - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 + min_value (Optional[int]): minimum value, defaults to 0 + max_value (Optional[int]): maximum value, defaults to 10 default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None @@ -1052,7 +1049,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min, max, default, step, parent, fixed) + box = class_(min_value, max_value, default, step, parent, fixed) boxes.append(box) return boxes @@ -1225,10 +1222,9 @@ def open_file_dialog( default_path = utils.parse_default_path(possible_paths) - f_name = QFileDialog.getOpenFileName( + return QFileDialog.getOpenFileName( widget, "Choose file", default_path, filetype ) - return f_name def open_folder_dialog( @@ -1238,10 +1234,9 @@ def open_folder_dialog( default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - filenames = QFileDialog.getExistingDirectory( + return QFileDialog.getExistingDirectory( widget, "Open directory", default_path + "/.." ) - return filenames def make_label(name, parent=None): # TODO update to child class @@ -1258,12 +1253,11 @@ def make_label(name, parent=None): # TODO update to child class label = QLabel(name, parent) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label else: label = QLabel(name) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label + return label def make_group(title, l=7, t=20, r=7, b=11, parent=None): diff --git a/pyproject.toml b/pyproject.toml index 81d2a788..7210af6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ exclude = [ "dist", "node_modules", "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", ] [tool.black] From 90e46566cb35b9275a6e3d8ab5dd51a8cec63b80 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:06:24 +0200 Subject: [PATCH 261/577] Small docs update --- docs/index.rst | 4 +- docs/res/code/instance_segmentation.rst | 53 +++++++++++++++++++ docs/res/code/model_instance_seg.rst | 53 ------------------- docs/res/code/plugin_convert.rst | 15 ------ docs/res/code/utils.rst | 4 -- .../code/{model_workers.rst => workers.rst} | 8 +-- docs/res/guides/custom_model_template.rst | 28 +++++++++- docs/res/guides/detailed_walkthrough.rst | 4 +- docs/res/guides/inference_module_guide.rst | 2 +- docs/res/guides/training_module_guide.rst | 2 +- napari_cellseg3d/code_models/workers.py | 28 +++++----- 11 files changed, 105 insertions(+), 96 deletions(-) create mode 100644 docs/res/code/instance_segmentation.rst delete mode 100644 docs/res/code/model_instance_seg.rst rename docs/res/code/{model_workers.rst => workers.rst} (78%) diff --git a/docs/index.rst b/docs/index.rst index 7e809fbe..46c57c08 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,8 +39,8 @@ Welcome to napari-cellseg3d's documentation! res/code/plugin_convert res/code/plugin_metrics res/code/model_framework - res/code/model_workers - res/code/model_instance_seg + res/code/workers + res/code/instance_segmentation res/code/plugin_model_inference res/code/plugin_model_training res/code/utils diff --git a/docs/res/code/instance_segmentation.rst b/docs/res/code/instance_segmentation.rst new file mode 100644 index 00000000..143560c4 --- /dev/null +++ b/docs/res/code/instance_segmentation.rst @@ -0,0 +1,53 @@ +instance_segmentation.py +=========================================== + +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::VoronoiOtsu + :members: __init__ + + +Functions +------------- + +binary_connected +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_connected + +binary_watershed +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_watershed + +volume_stats +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::volume_stats + +clear_small_objects +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::clear_small_objects + +to_instance +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_instance + +to_semantic +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_semantic diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst deleted file mode 100644 index 3b323173..00000000 --- a/docs/res/code/model_instance_seg.rst +++ /dev/null @@ -1,53 +0,0 @@ -model_instance_seg.py -=========================================== - -Classes -------------- - -InstanceMethod -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod - :members: __init__ - -ConnectedComponents -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents - :members: __init__ - -Watershed -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed - :members: __init__ - -VoronoiOtsu -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu - :members: __init__ - - -Functions -------------- - -binary_connected -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_connected - -binary_watershed -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_watershed - -volume_stats -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::volume_stats - -clear_small_objects -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::clear_small_objects - -to_instance -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_instance - -to_semantic -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_semantic diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 03944510..25006d0f 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,18 +28,3 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ - -Functions ------------------------------------ - -save_folder -***************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder - -save_layer -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer - -show_result -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index e90ee7e0..d9fdcfa2 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,7 +62,3 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images - -format_Warning -************************************** -.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/docs/res/code/model_workers.rst b/docs/res/code/workers.rst similarity index 78% rename from docs/res/code/model_workers.rst rename to docs/res/code/workers.rst index 85f8da29..1f5167ad 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/workers.rst @@ -1,4 +1,4 @@ -model_workers.py +workers.py =========================================== @@ -10,7 +10,7 @@ Class : LogSignal Attributes ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::LogSignal +.. autoclass:: napari_cellseg3d.code_models.workers::LogSignal :members: log_signal :noindex: @@ -24,7 +24,7 @@ Class : InferenceWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::InferenceWorker +.. autoclass:: napari_cellseg3d.code_models.workers::InferenceWorker :members: __init__, log, create_inference_dict, inference :noindex: @@ -39,6 +39,6 @@ Class : TrainingWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker +.. autoclass:: napari_cellseg3d.code_models.workers::TrainingWorker :members: __init__, log, train :noindex: diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 218795b1..a70df29b 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -3,9 +3,33 @@ Advanced : Declaring a custom model ============================================= -To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder: +.. warning:: + **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. + + Please `file an issue`_ if you would like to add a custom model and we will help you get it working. + +To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: + + class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) # download URL goes in pretrained_models.json + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass + .. note:: **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute -:: +.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues diff --git a/docs/res/guides/detailed_walkthrough.rst b/docs/res/guides/detailed_walkthrough.rst index 407893c2..3d06d998 100644 --- a/docs/res/guides/detailed_walkthrough.rst +++ b/docs/res/guides/detailed_walkthrough.rst @@ -1,6 +1,6 @@ .. _detailed_walkthrough: -Detailed walkthrough +Detailed walkthrough - Supervised learning =================================== The following guide will show you how to use the plugin's workflow, starting from human-labeled annotation volume, to running inference on novel volumes. @@ -109,7 +109,7 @@ of two no matter the size you choose. For optimal performance, make sure to use a power of two still, such as 64 or 120. .. important:: - Using a too large value for the size will cause memory issues. If this happens, restart napari (better handling for these situations might be added in the future). + Using a too large value for the size will cause memory issues. If this happens, restart the worker with smaller volumes. You also have the option to use data augmentation, which can improve performance and generalization. In most cases this should left enabled. diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst index 00e67078..373e9d0d 100644 --- a/docs/res/guides/inference_module_guide.rst +++ b/docs/res/guides/inference_module_guide.rst @@ -132,4 +132,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_inference` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index 05ce69be..1038dc6d 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -128,4 +128,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_training` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c1ed62fd..e2e21363 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -61,16 +61,6 @@ logger = utils.LOGGER -""" -Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) -so instead, following the instructions in the guides below to have a worker with custom signals, I implemented -a custom worker function.""" - -# FutureReference(): -# https://python-forum.io/thread-31349.html -# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ -# https://napari-staging-site.github.io/guides/stable/threading.html - PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) @@ -174,12 +164,23 @@ def safe_extract( ) +""" +Writing something to log messages from outside the main thread needs specific care, +Following the instructions in the guides below to have a worker with custom signals, +a custom worker function was implemented. +""" + +# https://python-forum.io/thread-31349.html +# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ +# https://napari-staging-site.github.io/guides/stable/threading.html + + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_ + Separate from Worker instances as indicated `on this post`_ - .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect """ # TODO link ? log_signal = Signal(str) @@ -196,6 +197,9 @@ def __init__(self): super().__init__() +# TODO(cyril): move inference and training workers to separate files + + @dataclass class InferenceResult: """Class to record results of a segmentation job""" From c4a6fe326a59e5f6b6f1951a34e83717588b238b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:24:43 +0200 Subject: [PATCH 262/577] Testing fix --- napari_cellseg3d/code_models/instance_segmentation.py | 5 ++--- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 5de7ab0c..2240e3bd 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -102,11 +102,10 @@ def _make_list_from_channels( image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - return [image] - return None + return [image] def run_method_on_channels(self, image): - image_list = self._make_list_from_channels(image) # FIXME rename + image_list = self._make_list_from_channels(image) result = np.array([self.run_method(im) for im in image_list]) return result.squeeze() diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 7235bd61..cb5ef6d8 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,7 @@ def __init__( num_classes=num_classes, ) - # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): From f76d2bc831b427af61482edc34f604ce1eacf167 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:59:05 +0200 Subject: [PATCH 263/577] Fixed multithread testing (locally) --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 14 +- .../_tests/test_plugin_inference.py | 29 ++-- napari_cellseg3d/_tests/test_training.py | 27 ++-- .../code_plugins/plugin_model_inference.py | 125 ++++++++++-------- .../code_plugins/plugin_model_training.py | 108 ++++++++------- 6 files changed, 158 insertions(+), 146 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 88a67ae2..fa6905d5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -9,6 +9,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35174b85..4852f651 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -52,7 +52,7 @@ def test_crf(qtbot): mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 - crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) + crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) @@ -60,20 +60,20 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=True ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() mock_image = mock_image[0] mock_label = mock_label[0] crf = CRFWorker(mock_image, mock_label) - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=False ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 3dafeabc..d1264218 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,10 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_inference(make_napari_viewer, qtbot): @@ -29,14 +28,16 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.setCurrentIndex(-1) - - # widget.start() # takes too long on Github Actions - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: - # blocker.connect(widget.worker.errored) - - #### assert len(viewer.layers) == 2 + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + widget.worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(widget.config) + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() # takes too long on Github Actions + assert widget.worker is not None + # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 921a6d26..4d558363 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,10 +2,9 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_training(make_napari_viewer, qtbot): @@ -33,15 +32,19 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - - # widget.start() - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. - # blocker.connect(widget.worker.errored) + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + + worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(worker_config) + + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() + assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index bb46617d..ba23e1df 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -551,64 +551,7 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.model_info = config.ModelInfo( - name=self.model_choice.currentText(), - model_input_size=self.model_input_size.value(), - ) - - self.weights_config.custom = self.custom_weights_choice.isChecked() - - save_path = self.results_filewidget.text_field.text() - if not self._check_results_path(save_path): - msg = f"ERROR: please set valid results path. Current path is {save_path}" - self.log.print_and_log(msg) - logger.warning(msg) - else: - if self.results_path is None: - self.results_path = save_path - - zoom_config = config.Zoom( - enabled=self.anisotropy_wdgt.enabled(), - zoom_values=self.anisotropy_wdgt.scaling_xyz(), - ) - thresholding_config = config.Thresholding( - enabled=self.thresholding_checkbox.isChecked(), - threshold_value=self.thresholding_slider.slider_value, - ) - - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] - ) - - self.post_process_config = config.PostProcessConfig( - zoom=zoom_config, - thresholding=thresholding_config, - instance=self.instance_config, - ) - - if self.window_infer_box.isChecked(): - size = int(self.window_size_choice.currentText()) - window_config = config.SlidingWindowConfig( - window_size=size, - window_overlap=self.window_overlap_slider.slider_value, - ) - else: - window_config = config.SlidingWindowConfig() - - self.worker_config = config.InferenceWorkerConfig( - device=self.get_device(), - model_info=self.model_info, - weights_config=self.weights_config, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), - compute_stats=self.save_stats_to_csv_box.isChecked(), - post_process_config=self.post_process_config, - sliding_window_config=window_config, - use_crf=self.use_crf.isChecked(), - crf_config=self.crf_widgets.make_config(), - ) + self._set_worker_config() ##################### ##################### ##################### @@ -650,6 +593,72 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.InferenceWorkerConfig): + return InferenceWorker(worker_config=config) + + def _set_worker_config(self) -> config.InferenceWorkerConfig: + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) + + self.weights_config.custom = self.custom_weights_choice.isChecked() + + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + logger.warning(msg) + else: + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], + ) + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, + ) + else: + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), + ) + return self.worker_config + def on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`""" self.display_status_report() diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 35a16799..e11eb3de 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -808,64 +808,10 @@ def start(self): self.data = None raise err - model_config = config.ModelInfo( - name=self.model_choice.currentText() - ) - - self.weights_config.path = self.weights_config.path - self.weights_config.custom = self.custom_weights_choice.isChecked() - self.weights_config.use_pretrained = ( - not self.use_transfer_choice.isChecked() - ) - - deterministic_config = config.DeterministicConfig( - enabled=self.use_deterministic_choice.isChecked(), - seed=self.box_seed.value(), - ) - - validation_percent = ( - self.validation_percent_choice.slider_value / 100 - ) - - results_path_folder = Path( - self.results_path - + f"/{model_config.name}_{utils.get_date_time()}" - ) - Path(results_path_folder).mkdir( - parents=True, exist_ok=False - ) # avoid overwrite where possible - - patch_size = [w.value() for w in self.patch_size_widgets] - - logger.debug("Loading config...") - self.worker_config = config.TrainingWorkerConfig( - device=self.get_device(), - model_info=model_config, - weights_info=self.weights_config, - train_data_dict=self.data, - validation_percent=validation_percent, - max_epochs=self.epoch_choice.value(), - loss_function=self.get_loss(self.loss_choice.currentText()), - learning_rate=float(self.learning_rate_choice.currentText()), - scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.slider_value, - validation_interval=self.val_interval_choice.value(), - batch_size=self.batch_choice.slider_value, - results_path_folder=str(results_path_folder), - sampling=self.patch_choice.isChecked(), - num_samples=self.sample_choice_slider.slider_value, - sample_size=patch_size, - do_augmentation=self.augment_choice.isChecked(), - deterministic_config=deterministic_config, - ) # TODO(cyril) continue to put params in config - self.config = config.TrainerConfig( save_as_zip=self.zip_choice.isChecked() ) - - self.log.print_and_log( - f"Saving results to : {results_path_folder}" - ) + self._set_worker_config() self.worker = TrainingWorker(config=self.worker_config) self.worker.set_download_log(self.log) @@ -895,6 +841,58 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.TrainingWorkerConfig): + return TrainingWorker(config=config) + + def _set_worker_config(self) -> config.TrainingWorkerConfig: + model_config = config.ModelInfo(name=self.model_choice.currentText()) + + self.weights_config.path = self.weights_config.path + self.weights_config.custom = self.custom_weights_choice.isChecked() + self.weights_config.use_pretrained = ( + not self.use_transfer_choice.isChecked() + ) + + deterministic_config = config.DeterministicConfig( + enabled=self.use_deterministic_choice.isChecked(), + seed=self.box_seed.value(), + ) + + validation_percent = self.validation_percent_choice.slider_value / 100 + + results_path_folder = Path( + self.results_path + f"/{model_config.name}_{utils.get_date_time()}" + ) + Path(results_path_folder).mkdir( + parents=True, exist_ok=False + ) # avoid overwrite where possible + + patch_size = [w.value() for w in self.patch_size_widgets] + + logger.debug("Loading config...") + self.worker_config = config.TrainingWorkerConfig( + device=self.get_device(), + model_info=model_config, + weights_info=self.weights_config, + train_data_dict=self.data, + validation_percent=validation_percent, + max_epochs=self.epoch_choice.value(), + loss_function=self.get_loss(self.loss_choice.currentText()), + learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, + validation_interval=self.val_interval_choice.value(), + batch_size=self.batch_choice.slider_value, + results_path_folder=str(results_path_folder), + sampling=self.patch_choice.isChecked(), + num_samples=self.sample_choice_slider.slider_value, + sample_size=patch_size, + do_augmentation=self.augment_choice.isChecked(), + deterministic_config=deterministic_config, + ) # TODO(cyril) continue to put params in config + + return self.worker_config + def on_start(self): """Catches started signal from worker""" From eced2135ca1f351c7e4950ee5776424ca972bf81 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:06:02 +0200 Subject: [PATCH 264/577] Added proper tests for train/infer --- .../_tests/test_plugin_inference.py | 36 ++++++++++++++----- napari_cellseg3d/_tests/test_training.py | 34 ++++++++++++------ napari_cellseg3d/code_models/workers.py | 4 +-- .../code_plugins/plugin_model_inference.py | 8 +++-- .../code_plugins/plugin_model_training.py | 10 ++++-- 5 files changed, 67 insertions(+), 25 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index d1264218..04305082 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -4,7 +4,10 @@ from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_inference import ( + InferenceResult, + Inferer, +) from napari_cellseg3d.config import MODEL_LIST @@ -28,16 +31,31 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(widget.config) - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() # takes too long on Github Actions - assert widget.worker is not None + assert widget.worker_config is not None + assert widget.model_info is not None + worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None + assert worker.config.model_info is not None + worker.config.layer = viewer.layers[0].data + assert worker.config.layer is not None + worker.log_parameters() + + res = next(worker.inference()) + assert isinstance(res, InferenceResult) + assert res.result.shape == (6, 6, 6) + + # def on_error(e): + # print(e) + # assert False + # with qtbot.waitSignal( + # signal=worker.finished, timeout=10000, raising=True + # ) as blocker: + # worker.error_signal.connect(on_error) + # blocker.connect(worker.errored) + # worker.inference() # takes too long on Github Actions # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 4d558363..080df419 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,7 +3,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_model_training import ( + Trainer, + TrainingReport, +) from napari_cellseg3d.config import MODEL_LIST @@ -32,19 +35,30 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(worker_config) - - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() - assert widget.worker is not None + worker = widget._create_worker_from_config(worker_config) + worker.config.train_data_dict = [{"image": im_path, "label": im_path}] + worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.log_parameters() + res = next(worker.train()) + + assert isinstance(res, TrainingReport) + + # def on_error(e): + # print(e) + # assert False + # + # with qtbot.waitSignal( + # signal=widget.worker.finished, timeout=10000, raising=True + # ) as blocker: + # blocker.connect(widget.worker.errored) + # widget.worker.error_signal.connect(on_error) + # widget.worker.train() + # assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index e2e21363..6dd32c80 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -965,7 +965,7 @@ class TrainingWorker(GeneratorWorker): def __init__( self, - config: config.TrainingWorkerConfig, + worker_config: config.TrainingWorkerConfig, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -1012,7 +1012,7 @@ def __init__( self._weight_error = False ############################################# - self.config = config + self.config = worker_config self.train_files = [] self.val_files = [] diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ba23e1df..1d8c0620 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -593,8 +593,12 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.InferenceWorkerConfig): - return InferenceWorker(worker_config=config) + def _create_worker_from_config( + self, worker_config: config.InferenceWorkerConfig + ): + if isinstance(worker_config, config.InfererConfig): + raise TypeError("Please provide a valid worker config object") + return InferenceWorker(worker_config=worker_config) def _set_worker_config(self) -> config.InferenceWorkerConfig: self.model_info = config.ModelInfo( diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index e11eb3de..2a131a5f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -841,8 +841,14 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.TrainingWorkerConfig): - return TrainingWorker(config=config) + def _create_worker_from_config( + self, worker_config: config.TrainingWorkerConfig + ): + if isinstance(config, config.TrainerConfig): + raise TypeError( + "Expected a TrainingWorkerConfig, got a TrainerConfig" + ) + return TrainingWorker(worker_config=worker_config) def _set_worker_config(self) -> config.TrainingWorkerConfig: model_config = config.ModelInfo(name=self.model_choice.currentText()) From bdb3799625487d253cea015b222f45f20671e3c6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:31:36 +0200 Subject: [PATCH 265/577] Slight coverage increase --- napari_cellseg3d/_tests/test_plugin_inference.py | 13 ++----------- napari_cellseg3d/_tests/test_training.py | 1 + napari_cellseg3d/code_models/models/model_test.py | 2 +- napari_cellseg3d/code_models/workers.py | 6 +++--- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 04305082..c437ac83 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -39,23 +39,14 @@ def test_inference(make_napari_viewer, qtbot): assert widget.worker_config is not None assert widget.model_info is not None worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data + worker.config.post_process_config.instance.enabled = True assert worker.config.layer is not None worker.log_parameters() res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) - - # def on_error(e): - # print(e) - # assert False - # with qtbot.waitSignal( - # signal=worker.finished, timeout=10000, raising=True - # ) as blocker: - # worker.error_signal.connect(on_error) - # blocker.connect(worker.errored) - # worker.inference() # takes too long on Github Actions - # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 080df419..e7f1e07b 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -43,6 +43,7 @@ def test_training(make_napari_viewer, qtbot): worker = widget._create_worker_from_config(worker_config) worker.config.train_data_dict = [{"image": im_path, "label": im_path}] worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.config.max_epochs = 1 worker.log_parameters() res = next(worker.train()) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1ccac3da..1cb52f06 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -8,7 +8,7 @@ class TestModel(nn.Module): def __init__(self, **kwargs): super().__init__() - self.linear = nn.Linear(1, 1) + self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 6dd32c80..8ddc7921 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1425,9 +1425,9 @@ def get_loader_func(num_samples): device = self.config.device - if model_name == "test": - self.quit() - yield TrainingReport(False) + # if model_name == "test": + # self.quit() + # yield TrainingReport(False) for epoch in range(self.config.max_epochs): # self.log("\n") From e4d424aadc6da97a090e2c8a48dfa52e028d5767 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:45:47 +0200 Subject: [PATCH 266/577] Update test_plugin_inference.py --- napari_cellseg3d/_tests/test_plugin_inference.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index c437ac83..ca8e84d4 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,6 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.instance_segmentation import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import ( InferenceResult, @@ -44,6 +47,10 @@ def test_inference(make_napari_viewer, qtbot): assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data worker.config.post_process_config.instance.enabled = True + worker.config.post_process_config.instance.method = ( + INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() + ) + assert worker.config.layer is not None worker.log_parameters() From 376b0d2d1db799b5332adc0e0992d1b78df93aee Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 11:41:39 +0200 Subject: [PATCH 267/577] Set window inference to 64 for WNet --- .../code_plugins/plugin_model_inference.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 1d8c0620..74dc62e5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -119,6 +119,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size ) + self.model_choice.currentIndexChanged.connect( + self._restrict_window_size_for_model + ) self.model_choice.setCurrentIndex(0) self.anisotropy_wdgt = ui.AnisotropyWidgets( @@ -150,9 +153,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.window_infer_box = ui.CheckBox("Use window inference") - self.window_infer_box.clicked.connect(self._toggle_display_window_size) + self.window_infer_box.toggled.connect(self._toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + self._default_window_size = sizes_window.index("64") # ( # self.window_size_choice, # self.window_size_choice.label, @@ -167,7 +171,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(3) # set to 64 by default + self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -192,7 +196,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) - self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## @@ -299,6 +302,19 @@ def check_ready(self): return True return False + def _restrict_window_size_for_model(self): + """Sets the window size to a value that is compatible with the chosen model""" + if self.model_choice.currentText() == "WNet": + self.window_size_choice.setCurrentIndex(self._default_window_size) + self.window_size_choice.setDisabled(True) + self.window_infer_box.setChecked(True) + self.window_infer_box.setDisabled(True) + else: + self.window_size_choice.setDisabled(False) + self.window_infer_box.setDisabled(False) + self.window_infer_box.setChecked(False) + self.window_size_choice.setCurrentIndex(self._default_window_size) + def _toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" From 934b95b63dc97ef32d5af0a821bf091a6e64777d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 22:00:16 +0200 Subject: [PATCH 268/577] Update instance_segmentation.py --- napari_cellseg3d/code_models/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 2240e3bd..93de0768 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -434,7 +434,7 @@ def sphericity(region): return ImageStats( volume, [region.centroid[0] for region in properties], - [region.centroid[0] for region in properties], + [region.centroid[1] for region in properties], [region.centroid[2] for region in properties], sphericity_ax, fill([volume_image.shape]), From 7156e0607fdbf6cb65f1107626221612b665b553 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 20 May 2023 09:22:52 +0200 Subject: [PATCH 269/577] Moved normalization to the correct place --- napari_cellseg3d/code_models/workers.py | 2 +- napari_cellseg3d/utils.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8ddc7921..dd9e38e3 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -492,9 +492,9 @@ def model_output( logger.debug(f"inputs type : {inputs.dtype}") try: # outputs = model(inputs) + inputs = utils.remap_image(inputs) def model_output_wrapper(inputs): - inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7ca29e00..90a64cfb 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -223,12 +223,18 @@ def normalize_max(image): def remap_image( - image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 + image: Union["np.ndarray", "torch.Tensor"], + new_max=100, + new_min=0, + prev_max=None, + prev_min=None, ): """Normalizes a numpy array or Tensor using the max and min value""" shape = image.shape image = image.flatten() - image = (image - image.min()) / (image.max() - image.min()) + im_max = prev_max if prev_max is not None else image.max() + im_min = prev_min if prev_min is not None else image.min() + image = (image - im_min) / (im_max - im_min) image = image * (new_max - new_min) + new_min image = image.reshape(shape) return image From b8276cfe7423affab0b6c74404769b19ee0f0d94 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 11:09:48 +0200 Subject: [PATCH 270/577] Added auto-set dims for cropping --- napari_cellseg3d/code_plugins/plugin_crop.py | 38 +++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index a27b4baa..e3ea55f5 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -3,6 +3,7 @@ import napari import numpy as np from magicgui import magicgui +from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -43,6 +44,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") + self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -112,6 +114,8 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() self._toggle_second_image_io_visibility() + self._check_image_list() + self.auto_set_dims() def _toggle_second_image_io_visibility(self): crop_2nd = self.crop_second_image_choice.isChecked() @@ -132,6 +136,16 @@ def _check_image_list(self): except IndexError: return + def auto_set_dims(self): + logger.debug(self.image_layer_loader.layer_name()) + data = self.image_layer_loader.layer_data() + if data is not None: + logger.debug("auto_set_dims : {}".format(data.shape)) + if len(data.shape) == 3: + for i, box in enumerate(self.crop_size_widgets): + logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + box.setValue(floor(data.shape[i] / 2)) + def _build(self): """Build buttons in a layout and add them to the napari Viewer""" @@ -266,9 +280,9 @@ def _start(self): except ValueError as e: logger.warning(e) logger.warning( - "Could not remove cropping layer programmatically!" + "Could not remove the previous cropping layer programmatically." ) - logger.warning("Maybe layer has been removed by user?") + # logger.warning("Maybe layer has been removed by user?") self.results_path = Path(self.results_filewidget.text_field.text()) @@ -346,7 +360,7 @@ def add_isotropic_layer( layer.data, name=f"Scaled_{layer.name}", colormap=colormap, - contrast_limits=contrast_lim, + # contrast_limits=contrast_lim, opacity=opacity, scale=self.aniso_factors, visible=visible, @@ -481,8 +495,8 @@ def set_slice( """ "Update cropped volume position""" # self._check_for_empty_layer(highres_crop_layer, highres_crop_layer.data) - logger.debug(f"axis : {axis}") - logger.debug(f"value : {value}") + # logger.debug(f"axis : {axis}") + # logger.debug(f"value : {value}") idx = int(value) scale = np.asarray(highres_crop_layer.scale) @@ -496,6 +510,20 @@ def set_slice( cropy = self._crop_size_y cropz = self._crop_size_z + if i + cropx > im1_stack.shape[0]: + cropx = im1_stack.shape[0] - i + if j + cropy > im1_stack.shape[1]: + cropy = im1_stack.shape[1] - j + if k + cropz > im1_stack.shape[2]: + cropz = im1_stack.shape[2] - k + + logger.debug(f"cropx : {cropx}") + logger.debug(f"cropy : {cropy}") + logger.debug(f"cropz : {cropz}") + logger.debug(f"i : {i}") + logger.debug(f"j : {j}") + logger.debug(f"k : {k}") + highres_crop_layer.data = im1_stack[ i : i + cropx, j : j + cropy, k : k + cropz ] From 76cc421157ec745e3b9de17b098d1316254f025e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 12:19:37 +0200 Subject: [PATCH 271/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 21 +++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 584be4d7..60c25ccc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,27 +1,34 @@ -from pathlib import Path - import numpy as np -from tifffile import imread +from numpy.random import PCG64, Generator from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, ) +rand_gen = Generator(PCG64(12345)) + def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - im_path = str(Path(__file__).resolve().parent / "res/test.tif") - image = imread(im_path) - view.add_image(image) - view.add_labels(image.astype(np.uint8)) + image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image_layer = view.add_image(image, name="image") + label_layer = view.add_labels(image.astype(np.uint8), name="labels") view.window.add_dock_widget(widget) + view.dims.ndisplay = 3 for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + + assert len(image_layer.data.shape) == 3 + assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() From a890e8a9aeba2da34f759753927851fa60df33e0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:50:18 +0200 Subject: [PATCH 272/577] More WNet - Added experimental .pt loading for jit models - More CRF tests - Optimized WNet by loading inference only --- napari_cellseg3d/_tests/test_models.py | 61 ++++++++++++------ napari_cellseg3d/code_models/crf.py | 8 ++- .../code_models/model_framework.py | 2 +- .../code_models/models/model_WNet.py | 18 +++--- .../code_models/models/wnet/model.py | 19 ++++-- napari_cellseg3d/code_models/workers.py | 62 ++++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 19 +++--- .../dev_scripts/correct_labels.py | 12 ++-- pyproject.toml | 1 + 9 files changed, 124 insertions(+), 78 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 4852f651..c67b3cab 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -2,9 +2,14 @@ import torch from numpy.random import PCG64, Generator -from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf +from napari_cellseg3d.code_models.crf import ( + CRFWorker, + correct_shape_for_crf, + crf_batch, + crf_with_config, +) from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss -from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d.config import MODEL_LIST, CRFConfig rand_gen = Generator(PCG64(12345)) @@ -47,7 +52,38 @@ def test_soft_ncuts_loss(): assert 0 <= res <= 1 -def test_crf(qtbot): +def test_crf_batch(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_batch( + np.array([mock_image, mock_image, mock_image]), + np.array([mock_label, mock_label, mock_label]), + sa=config.sa, + sb=config.sb, + sg=config.sg, + w1=config.w1, + w2=config.w2, + ) + + assert isinstance(result, np.ndarray) + assert result.shape == (3, 2, dims, dims, dims) + + +def test_crf_config(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_with_config(mock_image, mock_label, config) + assert isinstance(result, np.ndarray) + assert result.shape == mock_label.shape + + +def test_crf_worker(qtbot): dims = 8 mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) @@ -60,20 +96,5 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=True - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() - - mock_image = mock_image[0] - mock_label = mock_label[0] - - crf = CRFWorker(mock_image, mock_label) - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=False - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() + result = next(crf._run_crf_job()) + on_yield(result) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 8c311059..b362246a 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,6 +54,8 @@ def correct_shape_for_crf(image, desired_dims=4): + logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") + logger.debug(f"Image shape: {image.shape}") if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( @@ -62,6 +64,7 @@ def correct_shape_for_crf(image, desired_dims=4): image = np.squeeze(image, axis=0) elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) + logger.debug(f"Corrected image shape: {image.shape}") return image @@ -210,9 +213,12 @@ def _run_crf_job(self): if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - im = correct_shape_for_crf(self.labels[i]) + im = correct_shape_for_crf(self.images[i]) prob = correct_shape_for_crf(self.labels[i]) + logger.debug(f"image shape : {im.shape}") + logger.debug(f"labels shape : {prob.shape}") + yield crf( im, prob, diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 60644916..0296e0cf 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -281,7 +281,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth)", + filetype="Weights file (*.pth, *.pt)", ) if file[0] == self._default_weights_folder: return diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index cb5ef6d8..62142e73 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,8 +1,8 @@ # local -from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder -class WNet_(WNet): +class WNet_(WNet_encoder): use_default_training = False weights_file = "wnet.pth" @@ -24,13 +24,13 @@ def __init__( # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") - def forward(self, x): - """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference. - """ - return self.forward_encoder(x) - # enc = self.forward_encoder(x) - # return self.forward_decoder(enc) + # def forward(self, x): + # """Forward ENCODER pass of the W-Net model. + # Done this way to allow inference on the encoder only when called by sliding_window_inference. + # """ + # return self.forward_encoder(x) + # # enc = self.forward_encoder(x) + # # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 585ea0dd..a23084d0 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -16,6 +16,19 @@ ] +class WNet_encoder(nn.Module): + """WNet with encoder only.""" + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super().__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + + def forward(self, x): + """Forward pass of the W-Net model.""" + return self.forward_encoder(x) + + class WNet(nn.Module): """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. The model performs unsupervised segmentation of 3D images. @@ -36,13 +49,11 @@ def forward(self, x): def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" - enc = self.encoder(x) - return enc + return self.encoder(x) def forward_decoder(self, enc): """Forward pass of the decoder part of the W-Net model.""" - dec = self.decoder(enc) - return dec + return self.decoder(enc) class UNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index dd9e38e3..8b3da42d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -820,41 +820,43 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config - - # try: - self.log("Instantiating model...") - model = model_class( # FIXME test if works - input_img_size=[dims, dims, dims], - device=self.config.device, - num_classes=self.config.model_info.num_classes, - ) - # try: - model = model.to(self.config.device) - # except Exception as e: - # self.raise_error(e, "Issue loading model to device") - # logger.debug(f"model : {model}") - if model is None: - raise ValueError("Model is None") + if Path(weights_config.path).suffix == ".pt": + model = torch.jit.load(weights_config.path) # try: - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path else: - self.downloader.download_weights( - model_name, - model_class.weights_file, - ) - weights = str( - PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[dims, dims, dims], + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, + ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + ) - model.load_state_dict( # note that this is redefined in WNet_ - torch.load( - weights, - map_location=self.config.device, + model.load_state_dict( # note that this is redefined in WNet_ + torch.load( + weights, + map_location=self.config.device, + ) ) - ) - self.log("Done") + self.log("Done") # except Exception as e: # self.raise_error(e, "Issue loading weights") # except Exception as e: diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index e3ea55f5..74691e1f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,9 +1,9 @@ +from math import floor from pathlib import Path import napari import numpy as np from magicgui import magicgui -from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -44,7 +44,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") - self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self.auto_set_dims + ) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -140,10 +142,12 @@ def auto_set_dims(self): logger.debug(self.image_layer_loader.layer_name()) data = self.image_layer_loader.layer_data() if data is not None: - logger.debug("auto_set_dims : {}".format(data.shape)) + logger.debug(f"auto_set_dims : {data.shape}") if len(data.shape) == 3: for i, box in enumerate(self.crop_size_widgets): - logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + logger.debug( + f"setting dim {i} to {floor(data.shape[i]/2)}" + ) box.setValue(floor(data.shape[i] / 2)) def _build(self): @@ -433,9 +437,8 @@ def _add_crop_sliders( box.value() for box in self.crop_size_widgets ] ############# - dims = [self._x, self._y, self._z] - [logger.debug(f"{dim}") for dim in dims] - logger.debug("SET DIMS ATTEMPT") + # [logger.debug(f"{dim}") for dim in dims] + # logger.debug("SET DIMS ATTEMPT") # if not self.create_new_layer.isChecked(): # self._x = x # self._y = y @@ -451,6 +454,8 @@ def _add_crop_sliders( # define crop sizes and boundaries for the image crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] + # [logger.debug(f"{crop}") for crop in crop_sizes] + # logger.debug("SET CROP ATTEMPT") for i in range(len(crop_sizes)): if crop_sizes[i] > im1_stack.shape[i]: diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4a7363b2..f413812d 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -363,9 +363,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") -# -# image_path = str(im_path / "volumes/images.tif") -# gt_labels_path = str(im_path / "labels/testing_im.tif") -# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") + + image_path = str(im_path / "volumes/c1images.tif") + gt_labels_path = str(im_path / "labels/c1labels.tif") + relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) diff --git a/pyproject.toml b/pyproject.toml index 7210af6e..87cc2e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ docs = [ test = [ "pytest", "pytest_qt", + "pytest-cov", "coverage", "tox", "twine", From c135c41e89c6397824cc1df99c5ee376cec93a85 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:12:07 +0200 Subject: [PATCH 273/577] Update crf test/deps for testing --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index fa6905d5..0911e358 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index c67b3cab..ec7462db 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -68,7 +68,6 @@ def test_crf_batch(): w2=config.w2, ) - assert isinstance(result, np.ndarray) assert result.shape == (3, 2, dims, dims, dims) @@ -79,7 +78,6 @@ def test_crf_config(): config = CRFConfig() result = crf_with_config(mock_image, mock_label, config) - assert isinstance(result, np.ndarray) assert result.shape == mock_label.shape @@ -91,7 +89,6 @@ def test_crf_worker(qtbot): crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): - assert isinstance(result, np.ndarray) assert len(result.shape) == 4 assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] From 088f6215c1b13c0b8a3bfee6bad11c1d5ba76963 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:20:30 +0200 Subject: [PATCH 274/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0911e358..d09be5f0 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -87,6 +86,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From 0fc7fdce1c8b3b5b8aa21cf01fd7d3ec0fa41cba Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:34:33 +0200 Subject: [PATCH 275/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d09be5f0..d36e03a3 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -86,7 +87,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From 1c0b1b533989980722b1f95ec6123692361c4574 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:28 +0200 Subject: [PATCH 276/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 0a7c07f0..ee033e59 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 3edd4c7e154d3992c0e441e14560e999283659c7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:45 +0200 Subject: [PATCH 277/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d36e03a3..60bc5505 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox From ef1d37be30d84aa5887b0d8d8075c6e403c9aced Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:50:44 +0200 Subject: [PATCH 278/577] Trying to fix tox install of pydensecrf --- .github/workflows/test_and_deploy.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 60bc5505..e9a66ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions -# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/tox.ini b/tox.ini index ee033e59..ba3e8805 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 85863d15700c8eb65a86710d7ca954ff8f77e43d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:23:51 +0200 Subject: [PATCH 279/577] Added experimental ONNX support for inference --- .../code_models/model_framework.py | 15 ++++---- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 34 ++++++++++++++++++- .../code_plugins/plugin_model_inference.py | 14 +++++++- pyproject.toml | 8 +++++ 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 0296e0cf..f379ccb8 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -273,6 +273,14 @@ def get_available_models(): # self.lbl_model_path.setText(self.model_path) # # self.update_default() + def _update_weights_path(self, file): + if file[0] == self._default_weights_folder: + return + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) + def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`""" @@ -283,12 +291,7 @@ def _load_weights_path(self): [self._default_weights_folder], filetype="Weights file (*.pth, *.pt)", ) - if file[0] == self._default_weights_folder: - return - if file is not None and file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + self._update_weights_path(file) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index a23084d0..f98829bb 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -26,7 +26,7 @@ def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): def forward(self, x): """Forward pass of the W-Net model.""" - return self.forward_encoder(x) + return self.encoder(x) class WNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8b3da42d..be88c835 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,6 +199,34 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files +class ONNXModelWrapper(torch.nn.Module): + """Class to replace torch model if ONNX is used""" + def __init__(self, file_location): + super().__init__() + try: + import onnx + import onnxruntime as ort + except ImportError as e: + logger.error("ONNX is not installed but ONNX model was loaded") + logger.error(e) + msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" + logger.error(msg) + raise ImportError(msg) + + self.ort_session = ort.InferenceSession( + file_location, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + + def forward(self, modeL_input): + outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + return torch.tensor(outputs[0]) + + def eval(self): + return True + + def to(self, device): + return True @dataclass class InferenceResult: @@ -821,9 +849,13 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config if Path(weights_config.path).suffix == ".pt": + self.log("Instantiating PyTorch jit model...") model = torch.jit.load(weights_config.path) # try: - else: + elif Path(weights_config.path).suffix == ".onnx": + self.log("Instantiating ONNX model...") + model = ONNXModelWrapper(weights_config.path) + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 74dc62e5..599ec5b3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING - +from pathlib import Path import numpy as np import pandas as pd @@ -348,6 +348,18 @@ def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box""" ui.toggle_visibility(self.window_infer_box, self.window_infer_params) + def _load_weights_path(self): + """Show file dialog to set :py:attr:`model_path`""" + + # logger.debug(self._default_weights_folder) + + file = ui.open_file_dialog( + self, + [self._default_weights_folder], + filetype="Weights file (*.pth, *.pt, *.onnx)", + ) + self._update_weights_path(file) + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer""" diff --git a/pyproject.toml b/pyproject.toml index 87cc2e1d..2783761e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,3 +118,11 @@ test = [ "twine", "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] +onnx-cpu = [ + "onnx", + "onnxruntime" +] +onnx-gpu = [ + "onnx", + "onnxruntime-gpu" +] From c901733834c48aee34153d1db0fcbbcf9e06a332 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:47:48 +0200 Subject: [PATCH 280/577] Updated WNet for ONNX conversion --- .../code_models/models/wnet/model.py | 59 +++++++++++-------- napari_cellseg3d/code_models/workers.py | 9 ++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index f98829bb..23584b30 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -59,18 +59,33 @@ def forward_decoder(self, enc): class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, encoder=True): + def __init__( + self, device, in_channels, out_channels, encoder=True, dropout=0.65 + ): super(UNet, self).__init__() self.device = device - self.in_b = InBlock(device, in_channels, 64) - self.conv1 = Block(device, 64, 128) - self.conv2 = Block(device, 128, 256) - self.conv3 = Block(device, 256, 512) - self.bot = Block(device, 512, 1024) - self.deconv1 = Block(device, 1024, 512) - self.deconv2 = Block(device, 512, 256) - self.deconv3 = Block(device, 256, 128) - self.out_b = OutBlock(device, 128, out_channels) + self.max_pool = nn.MaxPool3d(2) + self.in_b = InBlock(device, in_channels, 64, dropout=dropout) + self.conv1 = Block(device, 64, 128, dropout=dropout) + self.conv2 = Block(device, 128, 256, dropout=dropout) + self.conv3 = Block(device, 256, 512, dropout=dropout) + self.bot = Block(device, 512, 1024, dropout=dropout) + self.deconv1 = Block(device, 1024, 512, dropout=dropout) + self.conv_trans1 = nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + ) + self.deconv2 = Block(device, 512, 256, dropout=dropout) + self.conv_trans2 = nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + ) + self.deconv3 = Block(device, 256, 128, dropout=dropout) + self.conv_trans3 = nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + ) + self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) + self.conv_trans_out = nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + ) self.sm = nn.Softmax(dim=1).to(device) self.encoder = encoder @@ -78,17 +93,15 @@ def __init__(self, device, in_channels, out_channels, encoder=True): def forward(self, x): """Forward pass of the U-Net model.""" in_b = self.in_b(x.to(self.device)) - c1 = self.conv1(nn.MaxPool3d(2)(in_b)) - c2 = self.conv2(nn.MaxPool3d(2)(c1)) - c3 = self.conv3(nn.MaxPool3d(2)(c2)) - x = self.bot(nn.MaxPool3d(2)(c3)) + c1 = self.conv1(self.max_pool(in_b)) + c2 = self.conv2(self.max_pool(c1)) + c3 = self.conv3(self.max_pool(c2)) + x = self.bot(self.max_pool(c3)) x = self.deconv1( torch.cat( [ c3, - nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device - )(x), + self.conv_trans1(x), ], dim=1, ) @@ -97,9 +110,7 @@ def forward(self, x): torch.cat( [ c2, - nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device - )(x), + self.conv_trans2(x), ], dim=1, ) @@ -108,9 +119,7 @@ def forward(self, x): torch.cat( [ c1, - nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device - )(x), + self.conv_trans3(x), ], dim=1, ) @@ -119,9 +128,7 @@ def forward(self, x): torch.cat( [ in_b, - nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device - )(x), + self.conv_trans_out(x), ], dim=1, ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index be88c835..bf6b8542 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -200,7 +200,7 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files class ONNXModelWrapper(torch.nn.Module): - """Class to replace torch model if ONNX is used""" + """Class to replace torch model by ONNX Runtime session""" def __init__(self, file_location): super().__init__() try: @@ -219,14 +219,17 @@ def __init__(self, file_location): ) def forward(self, modeL_input): + """Wraps ONNX output in a torch tensor""" outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) return torch.tensor(outputs[0]) def eval(self): - return True + """Dummy function to replace model.eval()""" + pass def to(self, device): - return True + """Dummy function to replace model.to(device)""" + pass @dataclass class InferenceResult: From 34d7fe6d4ddffd14b17885791b605d82fbf179c2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:56:45 +0200 Subject: [PATCH 281/577] Added dropout param --- .../code_models/models/wnet/model.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 23584b30..3416acb1 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -141,17 +141,17 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(InBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -163,19 +163,19 @@ def forward(self, x): class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(Block, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), nn.Conv3d(in_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.Conv3d(out_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -187,21 +187,21 @@ def forward(self, x): class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(OutBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, out_channels, 1, device=device), ).to(device) def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) + return self.module(x.to(self.device)) \ No newline at end of file From 800830b0aadb293c162dc26def8379ff69efb022 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 31 May 2023 16:13:42 +0200 Subject: [PATCH 282/577] Minor fixes in training --- napari_cellseg3d/code_models/workers.py | 8 ++++---- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 +++- napari_cellseg3d/interface.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index bf6b8542..c67ea523 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1531,13 +1531,13 @@ def get_loader_func(num_samples): or epoch + 1 == self.config.max_epochs ): model.eval() + self.log("Performing validation...") with torch.no_grad(): for val_data in val_loader: val_inputs, val_labels = ( val_data["image"].to(device), val_data["label"].to(device), ) - self.log("Performing validation...") try: with torch.no_grad(): val_outputs = sliding_window_inference( @@ -1606,8 +1606,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1620,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 2a131a5f..3e666dcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -169,6 +169,8 @@ def __init__( self.validation_values = [] # self.model_choice.setCurrentIndex(0) + wnet_index = self.model_choice.findText("WNet") + self.model_choice.removeItem(wnet_index) ################################ # interface @@ -813,7 +815,7 @@ def start(self): ) self._set_worker_config() - self.worker = TrainingWorker(config=self.worker_config) + self.worker = TrainingWorker(worker_config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index df00ad0b..e5b189ef 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1235,7 +1235,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") return QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + "/.." + widget, "Open directory", default_path # + "/.." ) From 69a25e54b48bccf92b102504083b3d22170606c4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:31:23 +0200 Subject: [PATCH 283/577] Fix weights file extension in inference + coverage - Remove unused scripts - More tests - Fixed weights type in inference --- .coveragerc | 7 + .gitignore | 1 + napari_cellseg3d/_tests/test_dock_widget.py | 1 + .../_tests/test_labels_correction.py | 8 +- .../_tests/test_plugin_inference.py | 2 + napari_cellseg3d/_tests/test_plugins.py | 21 ++ napari_cellseg3d/_tests/test_utils.py | 29 ++- .../code_models/model_framework.py | 28 +-- .../code_models/models/wnet/crf.py | 112 --------- napari_cellseg3d/code_plugins/plugin_crf.py | 6 +- .../code_plugins/plugin_metrics.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 +- napari_cellseg3d/dev_scripts/convert.py | 26 -- napari_cellseg3d/dev_scripts/drafts.py | 15 -- .../dev_scripts/evaluate_labels.py | 2 +- .../extract_extra_channels_labels.py | 144 ----------- napari_cellseg3d/dev_scripts/view_brain.py | 8 - napari_cellseg3d/dev_scripts/view_sample.py | 29 --- .../dev_scripts/weight_conversion.py | 234 ------------------ napari_cellseg3d/interface.py | 6 +- napari_cellseg3d/utils.py | 2 +- tox.ini | 4 +- 22 files changed, 75 insertions(+), 620 deletions(-) create mode 100644 .coveragerc create mode 100644 napari_cellseg3d/_tests/test_plugins.py delete mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py delete mode 100644 napari_cellseg3d/dev_scripts/convert.py delete mode 100644 napari_cellseg3d/dev_scripts/drafts.py delete mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py delete mode 100644 napari_cellseg3d/dev_scripts/view_brain.py delete mode 100644 napari_cellseg3d/dev_scripts/view_sample.py delete mode 100644 napari_cellseg3d/dev_scripts/weight_conversion.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..038f3d5a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +exclude_lines = + if __name__ == .__main__.: + +[run] +omit = + napari_cellseg3d/setup.py diff --git a/.gitignore b/.gitignore index df67a187..7460d861 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ notebooks/instance_test.ipynb !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif +cov.syspath.txt diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index 7737e540..8063c92b 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -11,6 +11,7 @@ def test_prepare(make_napari_viewer): viewer = make_napari_viewer() viewer.add_image(image) widget = Datamanager(viewer) + viewer.window.add_dock_widget(widget) widget.prepare(path_image, ".tif", "", False) diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..b4f13238 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -37,16 +37,16 @@ def test_correct_labels(): ) -def test_relabel(make_napari_viewer): - viewer = make_napari_viewer() +def test_relabel(): cl.relabel( str(image_path), str(labels_path), go_fast=True, - viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance(labels, labels, print_details=True) + el.evaluate_model_performance( + labels, labels, print_details=True, visualize=False + ) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index ca8e84d4..1ae83102 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -57,3 +57,5 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) + + widget.on_yield(res) diff --git a/napari_cellseg3d/_tests/test_plugins.py b/napari_cellseg3d/_tests/test_plugins.py new file mode 100644 index 00000000..c58d26af --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugins.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from napari_cellseg3d import plugins +from napari_cellseg3d.code_plugins import plugin_metrics as m + + +def test_all_plugins_import(make_napari_viewer): + plugins.napari_experimental_provide_dock_widget() + + +def test_plugin_metrics(make_napari_viewer): + viewer = make_napari_viewer() + w = m.MetricsUtils(viewer=viewer, parent=None) + viewer.window.add_dock_widget(w) + + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + labels_path = im_path + + w.image_filewidget.text_field = im_path + w.labels_filewidget.text_field = labels_path + w.compute_dice() diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 0b28183d..dc680b35 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,14 +1,15 @@ -import os from functools import partial +from pathlib import Path import numpy as np import torch from napari_cellseg3d import utils +from napari_cellseg3d.dev_scripts import thread_test def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] + test_list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -30,11 +31,11 @@ def test_fill_list_in_between(): "", ] - assert utils.fill_list_in_between(list, 2, "") == res + assert utils.fill_list_in_between(test_list, 2, "") == res fill = partial(utils.fill_list_in_between, n=2, fill_value="") - assert fill(list) == res + assert fill(test_list) == res def test_align_array_sizes(): @@ -109,11 +110,19 @@ def test_normalize_x(): def test_parse_default_path(): - user_path = os.path.expanduser("~") - assert utils.parse_default_path([None]) == user_path + user_path = Path().home() + assert utils.parse_default_path([None]) == str(user_path) - path = ["C:/test/test", None, None] - assert utils.parse_default_path(path) == "C:/test/test" + test_path = "C:/test/test" + path = [test_path, None, None] + assert utils.parse_default_path(path) == test_path - path = ["C:/test/test", None, None, "D:/very/long/path/what/a/bore", ""] - assert utils.parse_default_path(path) == "D:/very/long/path/what/a/bore" + long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime" + path = [test_path, None, None, long_path, ""] + assert utils.parse_default_path(path) == long_path + + +def test_thread_test(make_napari_viewer): + viewer = make_napari_viewer() + w = thread_test.create_connected_widget(viewer) + viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index f379ccb8..ddd9cd28 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -289,7 +289,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt)", + file_extension="Weights file (*.pth)", ) self._update_weights_path(file) @@ -311,31 +311,5 @@ def empty_cuda_cache(self): torch.cuda.empty_cache() logger.info("Attempt complete : Cache emptied") - # def update_default(self): # TODO add custom models - # """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" - # - # if len(self.images_filepaths) != 0: - # from_images = str(Path(self.images_filepaths[0]).parent) - # else: - # from_images = None - # - # if len(self.labels_filepaths) != 0: - # from_labels = str(Path(self.labels_filepaths[0]).parent) - # else: - # from_labels = None - # - # possible_paths = [ - # path - # for path in [ - # from_images, - # from_labels, - # # self.model_path, - # self.results_path, - # ] - # if path is not None - # ] - # self._default_folders = possible_paths - # update if model_path is used again - def _build(self): raise NotImplementedError("Should be defined in children classes") diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py deleted file mode 100644 index 004db3a1..00000000 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Implements the CRF post-processing step for the W-Net. -Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - -Also uses research from: -Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials -Philipp Krähenbühl and Vladlen Koltun -NIPS 2011 - -Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. -""" - -import numpy as np -import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - create_pairwise_bilateral, - create_pairwise_gaussian, - unary_from_softmax, -) - -__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" -__credits__ = [ - "Yves Paychère", - "Colin Hofmann", - "Cyril Achard", - "Philipp Krähenbühl", - "Vladlen Koltun", - "Liang-Chieh Chen", - "George Papandreou", - "Iasonas Kokkinos", - "Kevin Murphy", - "Alan L. Yuille", - "Xide Xia", - "Brian Kulis", - "Lucas Beyer", -] - - -def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): - """CRF post-processing step for the W-Net, applied to a batch of images. - - Args: - images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. - probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. - """ - - return np.stack( - [ - crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) - for i in range(images.shape[0]) - ], - axis=0, - ) - - -def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): - """Implements the CRF post-processing step for the W-Net. - Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - Implemented using the pydensecrf library. - - Args: - image (np.ndarray): Array of shape (C, H, W, D) containing the input image. - prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. - """ - d = dcrf.DenseCRF( - image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] - ) - # print(f"Image shape : {image.shape}") - # print(f"Prob shape : {prob.shape}") - # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels - - # Get unary potentials from softmax probabilities - U = unary_from_softmax(prob) - d.setUnaryEnergy(U) - - # Generate pairwise potentials - featsGaussian = create_pairwise_gaussian( - sdims=(sg, sg, sg), shape=image.shape[1:] - ) # image.shape) - featsBilateral = create_pairwise_bilateral( - sdims=(sa, sa, sa), - schan=tuple([sb for i in range(image.shape[0])]), - img=image, - chdim=-1, - ) - - # Add pairwise potentials to the CRF - compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( - [1 for i in range(prob.shape[0])] - # , dtype=np.float32 - ) - d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) - d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) - - # Run inference - Q = d.inference(n_iter) - - return np.array(Q).reshape( - (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) - ) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index d8407a0f..76194e87 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial from pathlib import Path @@ -277,7 +278,10 @@ def _on_start(self): def _on_finish(self): self.worker = None - self.start_button.setText("Start") + with contextlib.suppress(RuntimeError): + self.start_button.setText("Start") + + # should only happen when testing def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 2a6e713c..1dc5e7de 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -23,7 +23,7 @@ class MetricsUtils(BasePluginFolder): """Plugin to evaluate metrics between two sets of labels, ground truth and prediction""" - def __init__(self, viewer: "napari.viewer.Viewer", parent): + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): """Creates a MetricsUtils widget for computing and plotting dice metrics between labels. Args: viewer: viewer to display the widget in diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 599ec5b3..256cffa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING -from pathlib import Path + import numpy as np import pandas as pd @@ -171,7 +171,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default + self.window_size_choice.setCurrentIndex( + self._default_window_size + ) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -356,7 +358,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt, *.onnx)", + file_extension="Weights file (*.pth *.pt *.onnx)", ) self._update_weights_path(file) diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py deleted file mode 100644 index 641de627..00000000 --- a/napari_cellseg3d/dev_scripts/convert.py +++ /dev/null @@ -1,26 +0,0 @@ -import glob -import os - -import numpy as np -from tifffile import imread, imwrite - -# input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" -# output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" - -input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/labels" -output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/lab_sem" - -filenames = [] -paths = [] -filetype = ".tif" -for filename in glob.glob(os.path.join(input_seg_path, "*" + filetype)): - paths.append(filename) - filenames.append(os.path.basename(filename)) - # print(os.path.basename(filename)) -for file in paths: - image = imread(file) - - image[image >= 1] = 1 - image = image.astype(np.uint16) - - imwrite(output_seg_path + "/" + os.path.basename(file), image) diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py deleted file mode 100644 index cdd02256..00000000 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ /dev/null @@ -1,15 +0,0 @@ -import napari -import numpy as np -from magicgui import magicgui -from napari.types import ImageData, LabelsData - - -@magicgui(call_button="Run Threshold") -def threshold(image: ImageData, threshold: int = 75) -> LabelsData: - """Threshold an image and return a mask.""" - return (image > threshold).astype(int) - - -viewer = napari.view_image(np.random.randint(0, 100, (64, 64))) -viewer.window.add_dock_widget(threshold) -threshold() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 26b45d3f..00bce5ec 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -134,7 +134,7 @@ def evaluate_model_performance( log.info(mean_ratio_false_pixel_artefact) if visualize: - viewer = napari.Viewer() + viewer = napari.Viewer(ndisplay=3) viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py deleted file mode 100644 index 70ee10b6..00000000 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -from skimage.filters import threshold_otsu -from skimage.segmentation import expand_labels -from tqdm import tqdm - - -def extract_labels_from_channels( # TODO add separate channels results - nuclei_labels: np.array, - extra_channels: list, - radius: int = 4, - threshold_factor=2, - viewer=None, -): - """ - Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. - Args: - nuclei_labels (np.array): labels for the nuclei - extra_channels (list): channels arrays to extract labels from - radius: radius in which the approximation is made - - Returns: - A list of extracted labels for each extra channel - """ - labeled_channels = [] - contrasted_channels = [] - for channel in extra_channels: - channel = (channel - np.min(channel)) / ( - np.max(channel) - np.min(channel) - ) - threshold_brightness = threshold_otsu(channel) * threshold_factor - channel_contrasted = np.where( - channel > threshold_brightness, channel, 0 - ) - contrasted_channels.append(channel_contrasted) - if viewer is not None: - viewer.add_image( - channel_contrasted, - name="channel_contrasted", - colormap="viridis", - ) - for label_id in tqdm(np.unique(nuclei_labels)): - if label_id == 0: - continue - label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) - expanded = expand_labels(label_nucleus, distance=radius) - restricted = np.where(expanded != 0, nuclei_labels, 0) - overlap = np.where(restricted != label_id, restricted, 0) - - for i, channel in enumerate(contrasted_channels): - label_contrasted = np.where(expanded != 0, channel, 0) - if overlap.any() != 0: - max_labeled = 0 - for overlap_id in np.unique(overlap): - if overlap_id == 0: - continue - assigned_pixels = np.count_nonzero( - np.where(overlap == overlap_id, channel, 0) - ) - if assigned_pixels > max_labeled: - max_labeled = assigned_pixels - max_label_id = overlap_id - if label_id != max_label_id: - labeled_channels.append( - np.zeros_like(label_contrasted) - ) - else: - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels.append(labeled_channel) - if ( - np.count_nonzero(labeled_channel) > 0 - and viewer is not None - ): - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) - - cat_labels = np.zeros_like(nuclei_labels) - for labels in np.unique(labeled_channels): - if labels == 0: - continue - cat_labels += np.where(labels != 0, labels, 0) - return cat_labels - - -if __name__ == "__main__": - from pathlib import Path - - import napari - from tifffile import imread - - image_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - ) - # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - nuclei_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" - ) - extra_channels_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" - ) - extra_channels = [ - imread(str(path)) - for path in extra_channels_path.glob( - "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" - ) - ] - labels = imread(str(image_path)) - viewer = napari.Viewer() - - shift = 0 - viewer.add_image( - imread(str(nuclei_path))[ - shift : 32 + shift, shift : 32 + shift, shift : 32 + shift - ], - name="nuclei", - ) - viewer.add_labels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - [ - viewer.add_image( - channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - for channel in extra_channels - ] - - labeled_channels = extract_labels_from_channels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], - [ - c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - for c in extra_channels - ], - radius=4, - viewer=viewer, - ) - - viewer.add_labels(labeled_channels) - # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] - # expanded = expand_labels(labels, 4) - # viewer.add_labels(expanded) - napari.run() diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py deleted file mode 100644 index 145d4e45..00000000 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ /dev/null @@ -1,8 +0,0 @@ -import napari -from tifffile import imread - -y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") - -with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(y, contrast_limits=[0, 2000], multiscale=False) diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py deleted file mode 100644 index 8e87f85c..00000000 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ /dev/null @@ -1,29 +0,0 @@ -import napari -from tifffile import imread - -# Visual -x = imread( - "/Users/maximevidal/Documents/trailmap/data/no-edge-validation/visual-original/volumes/images.tif" -) -y_semantic = imread( - "/Users/maximevidal/Documents/trailmap/data/testing/seg-visual1-single/image.tif" -) -y_instance = imread( - "/Users/maximevidal/Documents/trailmap/data/instance-testing/test-visual-5.tiff" -) -y_true = imread( - "/Users/maximevidal/Documents/3drawdata/visual/labels/labels.tif" -) - -# SM -# x = imread("/Users/maximevidal/Documents/trailmap/data/no-edge-validation/validation-original/volumes/c5images.tif") -# y = imread("/Users/maximevidal/Documents/trailmap/data/instance-testing/test1.tiff") -# y_true = imread("/Users/maximevidal/Documents/3drawdata/somatomotor/labels/c5labels.tif") - -with napari.gui_qt(): - viewer = napari.view_image( - x, colormap="inferno", contrast_limits=[200, 1000] - ) - viewer.add_image(y_semantic, name="semantic_predictions", opacity=0.5) - viewer.add_labels(y_instance, name="instance_predictions", seed=0.6) - viewer.add_labels(y_true, name="truth", seed=0.6) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py deleted file mode 100644 index 6cdb9c43..00000000 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ /dev/null @@ -1,234 +0,0 @@ -import collections -import os - -import torch - -from napari_cellseg3d.code_models.models import get_net -from napari_cellseg3d.code_models.models.unet.model import UNet3D - -# not sure this actually works when put here - - -def weight_translate(k, w): - k = key_translate(k) - if k.endswith(".weight"): - if w.dim() == 2: - w = w.t() - elif w.dim() == 1: - pass - elif w.dim() == 4: - w = w.permute(3, 2, 0, 1) - else: - assert w.dim() == 5 - w = w.permute(4, 3, 0, 1, 2) - return w - - -def key_translate(k): - k = ( - k.replace( - "conv3d/kernel:0", - "encoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization/gamma:0", - "encoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization/beta:0", - "encoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_1/kernel:0", - "encoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_1/gamma:0", - "encoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_1/beta:0", - "encoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_2/kernel:0", - "encoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_2/gamma:0", - "encoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_2/beta:0", - "encoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_3/kernel:0", - "encoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_3/gamma:0", - "encoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_3/beta:0", - "encoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_4/kernel:0", - "encoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_4/gamma:0", - "encoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_4/beta:0", - "encoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_5/kernel:0", - "encoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_5/gamma:0", - "encoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_5/beta:0", - "encoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_6/kernel:0", - "encoders.3.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_6/gamma:0", - "encoders.3.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_6/beta:0", - "encoders.3.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_7/kernel:0", - "encoders.3.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_7/gamma:0", - "encoders.3.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_7/beta:0", - "encoders.3.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_8/kernel:0", - "decoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_8/gamma:0", - "decoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_8/beta:0", - "decoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_9/kernel:0", - "decoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_9/gamma:0", - "decoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_9/beta:0", - "decoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_10/kernel:0", - "decoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_10/gamma:0", - "decoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_10/beta:0", - "decoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_11/kernel:0", - "decoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_11/gamma:0", - "decoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_11/beta:0", - "decoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_12/kernel:0", - "decoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_12/gamma:0", - "decoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_12/beta:0", - "decoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_13/kernel:0", - "decoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_13/gamma:0", - "decoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_13/beta:0", - "decoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace("conv3d_14/kernel:0", "final_conv.weight") - .replace("conv3d_14/bias:0", "final_conv.bias") - ) - return k - - -model = get_net() -base_path = os.path.abspath(__file__ + "/..") -weights_path = base_path + "/data/model-weights/trailmap_model.hdf5" -model.load_weights(weights_path) - -for i, l in enumerate(model.layers): - print(i, l) - print( - "L{}: {}".format( - i, ", ".join(str(w.shape) for w in model.layers[i].weights) - ) - ) - -weights_pt = collections.OrderedDict( - [(w.name, torch.from_numpy(w.numpy())) for w in model.trainable_variables] -) -torch.save(weights_pt, base_path + "/data/model-weights/trailmaptorch.pt") -torch_weights = torch.load(base_path + "/data/model-weights/trailmaptorch.pt") -param_dict = { - key_translate(k): weight_translate(k, v) for k, v in torch_weights.items() -} - -trailmap_model = UNet3D(1, 1) -torchparam = trailmap_model.state_dict() -for k, v in torchparam.items(): - print("{:20s} {}".format(k, v.shape)) - -trailmap_model.load_state_dict(param_dict, strict=False) -torch.save( - trailmap_model.state_dict(), - base_path + "/data/model-weights/trailmaptorchpretrained.pt", -) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index e5b189ef..6a73eba0 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1207,7 +1207,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, possible_paths: list = (), - filetype: str = "Image file (*.tif *.tiff)", + file_extension: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1216,14 +1216,14 @@ def open_file_dialog( possible_paths (str): Paths that may have been chosen before, can be a string or an array of strings containing the paths load_as_folder (bool): Whether to open a folder or a single file. If True, will allow opening folder as a single file (2D stack interpreted as 3D) - filetype (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` + file_extension (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` """ default_path = utils.parse_default_path(possible_paths) return QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype + widget, "Choose file", default_path, file_extension ) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 90a64cfb..663872c4 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -520,7 +520,7 @@ def parse_default_path(possible_paths): # ] print(default_paths) if len(default_paths) == 0: - return str(Path.home()) + return str(Path().home()) default_path = max(default_paths, key=len) return str(default_path) diff --git a/tox.ini b/tox.ini index ba3e8805..0605fc8c 100644 --- a/tox.ini +++ b/tox.ini @@ -38,5 +38,7 @@ deps = qtpy git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] - +; opencv-python +extras = crf +usedevelop = true commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 102362f670d2b89d02856b686c4cce300ad2696f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:41:07 +0200 Subject: [PATCH 284/577] Run all hooks --- .../_tests/test_plugin_inference.py | 5 ++++- .../code_models/models/model_TRAILMAP.py | 15 +++++--------- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 20 +++++++++++-------- napari_cellseg3d/code_plugins/plugin_base.py | 15 ++++++-------- .../code_plugins/plugin_helper.py | 4 +++- .../code_plugins/plugin_utilities.py | 5 ++++- napari_cellseg3d/dev_scripts/thread_test.py | 6 ++++-- pyproject.toml | 2 +- 9 files changed, 40 insertions(+), 34 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1ae83102..1e486c14 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,9 +34,12 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() + widget.model_choice.setCurrentIndex(-1) + assert widget.window_infer_box.isChecked() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + widget.model_choice.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 8a108e37..e6bbad55 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -39,13 +39,12 @@ def forward(self, x): up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 # print(up8.shape) - out = self.out(up8) + return self.out(up8) # print("out:") # print(out.shape) - return out def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -56,10 +55,9 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.ReLU(), nn.MaxPool3d(2), ) - return encode def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -69,10 +67,9 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.BatchNorm3d(out_ch), nn.ReLU(), ) - return encode def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -85,13 +82,11 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) ), ) - return decode def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) - return out class TRAILMAP_(TRAILMAP): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 3416acb1..2900b89c 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -204,4 +204,4 @@ def __init__(self, device, in_channels, out_channels, dropout=0.65): def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) \ No newline at end of file + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c67ea523..245e6f02 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,28 +199,31 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files + class ONNXModelWrapper(torch.nn.Module): """Class to replace torch model by ONNX Runtime session""" + def __init__(self, file_location): super().__init__() try: - import onnx import onnxruntime as ort except ImportError as e: logger.error("ONNX is not installed but ONNX model was loaded") logger.error(e) msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" logger.error(msg) - raise ImportError(msg) + raise ImportError(msg) from e self.ort_session = ort.InferenceSession( file_location, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) def forward(self, modeL_input): """Wraps ONNX output in a torch tensor""" - outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + outputs = self.ort_session.run( + None, {"input": modeL_input.cpu().numpy()} + ) return torch.tensor(outputs[0]) def eval(self): @@ -231,6 +234,7 @@ def to(self, device): """Dummy function to replace model.to(device)""" pass + @dataclass class InferenceResult: """Class to record results of a segmentation job""" @@ -858,7 +862,7 @@ def inference(self): elif Path(weights_config.path).suffix == ".onnx": self.log("Instantiating ONNX model...") model = ONNXModelWrapper(weights_config.path) - else: # assume is .pth + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], @@ -1606,8 +1610,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1624,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 26da7a42..cfa3f0d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -227,17 +227,16 @@ def _show_filetype_choice(self): def _show_file_dialog(self): """Open file dialog and process path depending on single file/folder loading behaviour""" if self.load_as_stack_choice.isChecked(): - folder = ui.open_folder_dialog( + choice = ui.open_folder_dialog( self, self._default_path, filetype=f"Image file (*{self.filetype_choice.currentText()})", ) - return folder else: f_name = ui.open_file_dialog(self, self._default_path) - f_name = str(f_name[0]) - self.filetype = str(Path(f_name).suffix) - return f_name + choice = str(f_name[0]) + self.filetype = str(Path(choice).suffix) + return choice def _show_dialog_images(self): """Show file dialog and set image path""" @@ -291,16 +290,14 @@ def _make_close_button(self): return btn def _make_prev_button(self): - btn = ui.Button( + return ui.Button( "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) ) - return btn def _make_next_button(self): - btn = ui.Button( + return ui.Button( "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) ) - return btn def remove_from_viewer(self): """Removes the widget from the napari window. diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index a3fd8c0d..54c34a8f 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,6 +1,8 @@ import pathlib +from typing import TYPE_CHECKING -import napari +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import QSize diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 868dd279..6e1a606a 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -1,4 +1,7 @@ -import napari +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import qInstallMessageHandler diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 20668125..a48f6db0 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -1,8 +1,8 @@ import time import napari -import numpy as np from napari.qt.threading import thread_worker +from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -13,6 +13,8 @@ QWidget, ) +rand_gen = Generator(PCG64(12345)) + @thread_worker def two_way_communication_with_args(start, end): @@ -129,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) + viewer = napari.view_image(rand_gen.random(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/pyproject.toml b/pyproject.toml index 2783761e..f71ddb23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ select = [ ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) # and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004"] +ignore = ["E501", "E741", "G004", "A003"] exclude = [ ".bzr", ".direnv", From c810fcc060d5a69dbd21c49e47a880ec11d76ee4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 11:27:58 +0200 Subject: [PATCH 285/577] Fix inference testing --- .../_tests/test_plugin_inference.py | 13 +++++++----- .../code_models/models/model_test.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1e486c14..779f5094 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,12 +34,15 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentIndex(-1) + widget.model_choice.setCurrentText("WNet") + widget._restrict_window_size_for_model() assert widget.window_infer_box.isChecked() + assert widget.window_size_choice.currentText() == "64" - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(-1) + test_model_name = "test" + MODEL_LIST[test_model_name] = TestModel + widget.model_choice.addItem(test_model_name) + widget.model_choice.setCurrentText(test_model_name) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None @@ -59,6 +62,6 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) - assert res.result.shape == (6, 6, 6) + assert res.result.shape == (8, 8, 8) widget.on_yield(res) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1cb52f06..28f3a05b 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -20,13 +20,13 @@ def forward(self, x): # return val_inputs -# if __name__ == "__main__": -# -# model = TestModel() -# model.train() -# model.zero_grad() -# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR -# torch.save( -# model.state_dict(), -# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" -# ) +if __name__ == "__main__": + model = TestModel() + model.train() + model.zero_grad() + from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR + + torch.save( + model.state_dict(), + PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", + ) From 28eda5c62b76c58a25298011b5638ccd08b9c20a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 13:45:50 +0200 Subject: [PATCH 286/577] Changed anisotropy calculation --- napari_cellseg3d/_tests/test_interface.py | 8 +++++++- napari_cellseg3d/interface.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index be811721..08e0e675 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.interface import Log +from napari_cellseg3d.interface import AnisotropyWidgets, Log def test_log(qtbot): @@ -12,3 +12,9 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) + + +def test_zoom_factor(): + resolution = [10.0, 10.0, 5.0] + zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) + assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 6a73eba0..57d78795 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -734,8 +734,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = min(aniso_res) - return [base / res for res in aniso_res] + base = max(aniso_res) + return [res / base for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" From b453eeaceaefb34c0a8f55dd2f9122a6e2e90be3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 11:37:13 +0200 Subject: [PATCH 287/577] Finish rebase + bump version --- napari_cellseg3d/__init__.py | 2 +- .../code_models/instance_segmentation.py | 105 +++++++----------- .../code_plugins/plugin_helper.py | 2 +- .../dev_scripts/evaluate_labels.py | 81 +++++++++++++- pyproject.toml | 3 +- setup.cfg | 2 +- 6 files changed, 119 insertions(+), 76 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 6e2681e8..be8123e4 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc1" +__version__ = "0.0.3rc1" diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 93de0768..f5066ebe 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -9,9 +9,6 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui @@ -110,42 +107,6 @@ def run_method_on_channels(self, image): return result.squeeze() -class InstanceMethod: - def __init__( - self, - name: str, - function: callable, - num_sliders: int, - num_counters: int, - ): - self.name = name - self.function = function - self.counters: List[ui.DoubleIncrementCounter] = [] - self.sliders: List[ui.Slider] = [] - if num_sliders > 0: - for i in range(num_sliders): - widget = f"slider_{i}" - setattr( - self, - widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), - ) - self.sliders.append(getattr(self, widget)) - - if num_counters > 0: - for i in range(num_counters): - widget = f"counter_{i}" - setattr( - self, - widget, - ui.DoubleIncrementCounter(label=""), - ) - self.counters.append(getattr(self, widget)) - - def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") - - @dataclass class ImageStats: volume: List[float] @@ -186,7 +147,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -202,12 +163,13 @@ def voronoi_otsu( Instance segmentation labels from Voronoi-Otsu method """ - semantic = np.squeeze(volume) + # remove_small_size (float): remove all objects smaller than the specified size in pixels + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -225,8 +187,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) - """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -445,13 +405,16 @@ def sphericity(region): ) -class Watershed(InstanceMethod, metaclass=Singleton): - def __init__(self): +class Watershed(InstanceMethod): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + + def __init__(self, widget_parent=None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -488,13 +451,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): - def __init__(self): +class ConnectedComponents(InstanceMethod): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + + def __init__(self, widget_parent=None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -516,33 +482,37 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): - def __init__(self): +class VoronoiOtsu(InstanceMethod): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + + def __init__(self, widget_parent=None): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, + widget_parent=widget_parent, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): ################ @@ -557,7 +527,7 @@ def run_method(self, image): image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) @@ -575,7 +545,6 @@ def __init__(self, parent=None): """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) @@ -588,7 +557,6 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) @@ -620,6 +588,9 @@ def _set_visibility(self): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -636,7 +607,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 54c34a8f..552f70ea 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -39,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.3rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 00bce5ec..64fbaf5e 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,5 @@ import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -128,9 +126,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: @@ -190,6 +186,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file diff --git a/pyproject.toml b/pyproject.toml index f71ddb23..e39a7522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "napari_cellseg3d" -version = "0.0.2rc6" +version = "0.0.3rc1" authors = [ {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, @@ -102,6 +102,7 @@ dev = [ "black", "ruff", "pre-commit", + "tuna", ] docs = [ "sphinx", diff --git a/setup.cfg b/setup.cfg index f3294b60..8ee82f96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc6 +version = 0.0.3rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 53d30020ff482f46012ddcca115e43a7abc5d315 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 12:12:28 +0200 Subject: [PATCH 288/577] Fixed aniso correction and CRF interaction --- napari_cellseg3d/code_models/workers.py | 17 ++++++++++++++--- napari_cellseg3d/interface.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 245e6f02..50f85395 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -724,7 +724,12 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): instance_labels, stats = self.get_instance_result(out, i=i) if self.config.use_crf: try: - crf_results = self.run_crf(inputs, out, image_id=i) + crf_results = self.run_crf( + inputs, + out, + aniso_transform=self.aniso_transform, + image_id=i, + ) except ValueError as e: self.log(f"Error occurred during CRF : {e}") @@ -746,8 +751,10 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): i=i, ) - def run_crf(self, image, labels, image_id=0): + def run_crf(self, image, labels, aniso_transform, image_id=0): try: + if aniso_transform is not None: + image = aniso_transform(image) crf_results = crf_with_config( image, labels, config=self.config.crf_config, log=self.log ) @@ -795,7 +802,11 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - crf_results = self.run_crf(image, out) if self.config.use_crf else None + crf_results = ( + self.run_crf(image, out, aniso_transform=self.aniso_transform) + if self.config.use_crf + else None + ) return self.create_inference_result( semantic_labels=out, diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d2ec5789..014c17b6 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -667,7 +667,7 @@ def __init__( w.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) self.box_widgets_lbl = [ - make_label("Resolution in " + axis + " (microns) :", parent=parent) + make_label("Pixel size in " + axis + " (microns) :", parent=parent) for axis in "xyz" ] From 1063f7f6ae5c67b425b61bc149bf570b33ee91dd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 12:20:04 +0200 Subject: [PATCH 289/577] Remove duplicate tests --- .github/workflows/test_and_deploy.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index e9a66ae2..105c260a 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -7,17 +7,12 @@ on: push: branches: - main - - npe2 - - cy/voronoi-otsu - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - - npe2 - - cy/voronoi-otsu - - cy/wnet workflow_dispatch: jobs: From ade9d5a7821b59fad5ee31819be34b6807a42431 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 290/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 193 ++++++++++-------- napari_cellseg3d/code_models/model_workers.py | 11 +- .../code_plugins/plugin_convert.py | 29 +-- .../code_plugins/plugin_model_inference.py | 20 +- napari_cellseg3d/config.py | 15 +- napari_cellseg3d/interface.py | 58 +++--- requirements.txt | 4 +- 8 files changed, 170 insertions(+), 161 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..584ffd3b 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,6 +8,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 60f8bbda..4d0d5c78 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -2,19 +2,26 @@ from typing import List import numpy as np + import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label, regionprops + +from skimage.measure import label +from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed +from skimage.filters import thresholding +from skimage.transform import resize + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between +from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area @@ -79,6 +86,42 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -119,32 +162,27 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - # remove_small_size: float, + remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant - Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels Returns: Instance segmentation labels from Voronoi-Otsu method - """ - # remove_small_size (float): remove all objects smaller than the specified size in pixels - # semantic = np.squeeze(volume) - logger.debug( - f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" - ) + semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( - volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return np.array(instance) + return instance def binary_connected( @@ -159,8 +197,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) - """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -380,16 +416,13 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod): - """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - - def __init__(self, widget_parent=None): +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=WATERSHED, + name="Watershed", function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -419,23 +452,20 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].slider_value, - self.sliders[1].slider_value, + self.sliders[0].value(), + self.sliders[1].value(), self.counters[0].value(), self.counters[1].value(), ) -class ConnectedComponents(InstanceMethod): - """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - - def __init__(self, widget_parent=None): +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=CONNECTED_COMP, + name="Connected Components", function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -453,56 +483,44 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].slider_value, self.counters[0].value() + image, self.sliders[0].value(), self.counters[0].value() ) -class VoronoiOtsu(InstanceMethod): - """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - - def __init__(self, widget_parent=None): +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=VORONOI_OTSU, + name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=2, - widget_parent=widget_parent, + num_counters=3, ) - self.counters[0].label.setText("Spot sigma") # closeness + self.counters[0].label.setText("Spot sigma") self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[1].label.setText("Outline sigma") self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" self.counters[1].setMaximum(100) self.counters[1].setValue(2) - # self.counters[2].label.setText("Small object removal") - # self.counters[2].tooltips = ( - # "Volume/size threshold for small object removal." - # "\nAll objects with a volume/size below this value will be removed." - # ) - # self.counters[2].setValue(30) + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) def run_method(self, image): - ################ - # For debugging - # import napari - # view = napari.Viewer() - # view.add_image(image) - # napari.run() - ################ - return self.function( image, self.counters[0].value(), self.counters[1].value(), - # self.counters[2].value(), + self.counters[2].value(), ) @@ -517,72 +535,67 @@ def __init__(self, parent=None): Args: parent: parent widget - """ super().__init__(parent) + self.method_choice = ui.DropdownMenu( - list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) + INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = {} - """Contains the instance of the method, with its name as key""" + self.methods = [] self.instance_widgets = {} - """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() def _build(self): + group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - try: - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - method_class = method(widget_parent=self.parent()) - self.methods[name] = method_class - self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets ? - if len(method_class.sliders) > 0: - for slider in method_class.sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method_class.counters) > 0: - for counter in method_class.counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug( - f"Caught runtime error {e}, most likely during testing" - ) + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets: - if name != self.method_choice.currentText(): - for widget in self.instance_widgets[name]: + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: widget.set_visibility(False) - else: - for widget in self.instance_widgets[name]: - widget.set_visibility(True) def run_method(self, volume): """ Calls instance function with chosen parameters - Args: volume: image data to run method on Returns: processed image from self._method - """ - method = self.methods[self.method_choice.currentText()] + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() return method.run_method(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { - VORONOI_OTSU: VoronoiOtsu, - WATERSHED: Watershed, - CONNECTED_COMP: ConnectedComponents, + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, } diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 30d37bbd..06eff6e5 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -51,10 +51,9 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_instance_seg import ( - ImageStats, - volume_stats, -) +from napari_cellseg3d import utils +from napari_cellseg3d.code_models.model_instance_seg import ImageStats +from napari_cellseg3d.code_models.model_instance_seg import volume_stats logger = utils.LOGGER @@ -600,8 +599,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c8370c1..b7064567 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -14,6 +14,9 @@ threshold, to_semantic, ) +from napari_cellseg3d.code_models.model_instance_seg import threshold +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -345,20 +348,18 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) - + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): """ diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 22867343..158203af 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,13 +9,12 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceMethod, - InstanceWidgets, -) -from napari_cellseg3d.code_models.model_workers import ( - InferenceResult, - InferenceWorker, + INSTANCE_SEGMENTATION_METHOD_LIST, ) @@ -556,12 +555,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 737b53aa..c9e12f06 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,7 +7,6 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -16,6 +15,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -113,17 +118,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 276f9214..397d4e48 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,31 +7,33 @@ # Qt # from qtpy.QtCore import QtWarningMsg -from qtpy import QtCore -from qtpy.QtCore import QObject, Qt, QUrl -from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor -from qtpy.QtWidgets import ( - QCheckBox, - QComboBox, - QDoubleSpinBox, - QFileDialog, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QLineEdit, - QMenu, - QPushButton, - QRadioButton, - QScrollArea, - QSizePolicy, - QSlider, - QSpinBox, - QTextEdit, - QVBoxLayout, - QWidget, -) +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt +# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QUrl +from qtpy.QtGui import QCursor +from qtpy.QtGui import QDesktopServices +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import QCheckBox +from qtpy.QtWidgets import QComboBox +from qtpy.QtWidgets import QDoubleSpinBox +from qtpy.QtWidgets import QFileDialog +from qtpy.QtWidgets import QGridLayout +from qtpy.QtWidgets import QGroupBox +from qtpy.QtWidgets import QHBoxLayout +from qtpy.QtWidgets import QLabel +from qtpy.QtWidgets import QLayout +from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QMenu +from qtpy.QtWidgets import QPushButton +from qtpy.QtWidgets import QRadioButton +from qtpy.QtWidgets import QScrollArea +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QSlider +from qtpy.QtWidgets import QSpinBox +from qtpy.QtWidgets import QTextEdit +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # Local from napari_cellseg3d import utils @@ -1046,11 +1048,11 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): # FIXME check divide_factor + def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.01) - else: self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): diff --git a/requirements.txt b/requirements.txt index 3189e9c4..9c7126eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pre-commit -pyclesperanto-prototype>=0.22.0 -pysqlite3 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From bc89d2035c1e12107cc7c5e9e2c2900e06b783f0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 291/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 4d0d5c78..85be5bde 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -162,7 +162,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -172,11 +172,12 @@ def voronoi_otsu( volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels + Returns: Instance segmentation labels from Voronoi-Otsu method """ + # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma @@ -493,7 +494,7 @@ def __init__(self): name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, ) self.counters[0].label.setText("Spot sigma") self.counters[ @@ -509,18 +510,19 @@ def __init__(self): self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): return self.function( image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) From 2396f9df10d85ca3e9a5ed245ef64067f25b3d31 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 292/577] Added new docs for instance seg --- .../code_models/model_instance_seg.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 85be5bde..3a6136a9 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -94,6 +94,14 @@ def __init__( num_sliders: int, num_counters: int, ): + """ + Methods for instance segmentation + Args: + name: Name of the instance segmentation method (for UI) + function: Function to use for instance segmentation + num_sliders: Number of Slider UI elements needed to set the parameters of the function + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + """ self.name = name self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] @@ -418,6 +426,8 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod, metaclass=Singleton): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + def __init__(self): super().__init__( name="Watershed", @@ -461,6 +471,8 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod, metaclass=Singleton): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + def __init__(self): super().__init__( name="Connected Components", @@ -489,6 +501,8 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + def __init__(self): super().__init__( name="Voronoi-Otsu", @@ -496,14 +510,14 @@ def __init__(self): num_sliders=0, num_counters=2, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -597,7 +611,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { + VoronoiOtsu().name: VoronoiOtsu, Watershed().name: Watershed, ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, } From ea900abd7633e967728b3ecfa0071eb2b67e2f73 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 293/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 2 +- docs/res/welcome.rst | 4 +++- napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ece6c6f4..ca8d0931 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 892549a8..12a20630 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -103,6 +103,8 @@ This plugin mainly uses the following libraries and software: * `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase + * A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ @@ -113,7 +115,7 @@ This plugin mainly uses the following libraries and software: .. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype -.. _WNet model: https://arxiv.org/abs/1711.08506 + .. rubric:: References diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 397d4e48..11b2490e 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1050,9 +1050,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From 0e06886cf91de914d2ac11d2a08fbb5630ed2780 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 294/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c7126eb..834a225e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 90fb697e703582213fd1af10fac8518d86982c31 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 295/577] isort --- napari_cellseg3d/code_models/model_instance_seg.py | 13 ++++--------- napari_cellseg3d/code_plugins/plugin_convert.py | 13 ++++--------- .../code_plugins/plugin_model_inference.py | 8 ++++---- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 3 ++- 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 3a6136a9..b4fc5c32 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,27 +1,22 @@ +from __future__ import division +from __future__ import print_function from dataclasses import dataclass from typing import List - import numpy as np - import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget - from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -from skimage.filters import thresholding -from skimage.transform import resize - +from tifffile import imread # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes -from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index b7064567..a04d9f09 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,22 +1,17 @@ import warnings from pathlib import Path - import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread +from tifffile import imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceWidgets, - clear_small_objects, - threshold, - to_semantic, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 158203af..1da36989 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,13 +9,13 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( INSTANCE_SEGMENTATION_METHOD_LIST, ) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index c9e12f06..3e92b88e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,6 +7,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -15,12 +20,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 11b2490e..d697245a 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -9,7 +9,8 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt -# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From c0b1cdc49f72b83e534258eb3bdab95aadf458bc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 296/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 45c3be1c..814cca2e 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 +qt_api=pyqt5 \ No newline at end of file From ccbf7afa5b39338fabb6d07dda59a7932f28f28b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 297/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/model_instance_seg.py | 130 ++++++------------ 1 file changed, 40 insertions(+), 90 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index b4fc5c32..253005b3 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,8 +15,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area @@ -33,18 +33,16 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None, + widget_parent: QWidget = None ): """ Methods for instance segmentation - Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets - """ self.name = name self.function = function @@ -56,14 +54,7 @@ def __init__( setattr( self, widget, - ui.Slider( - 0, - 100, - 1, - divide_factor=100, - text_label="", - parent=None, - ), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -81,50 +72,6 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") -class InstanceMethod: - def __init__( - self, - name: str, - function: callable, - num_sliders: int, - num_counters: int, - ): - """ - Methods for instance segmentation - Args: - name: Name of the instance segmentation method (for UI) - function: Function to use for instance segmentation - num_sliders: Number of Slider UI elements needed to set the parameters of the function - num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function - """ - self.name = name - self.function = function - self.counters: List[ui.DoubleIncrementCounter] = [] - self.sliders: List[ui.Slider] = [] - if num_sliders > 0: - for i in range(num_sliders): - widget = f"slider_{i}" - setattr( - self, - widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), - ) - self.sliders.append(getattr(self, widget)) - - if num_counters > 0: - for i in range(num_counters): - widget = f"counter_{i}" - setattr( - self, - widget, - ui.DoubleIncrementCounter(label=""), - ) - self.counters.append(getattr(self, widget)) - - def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") - - @dataclass class ImageStats: volume: List[float] @@ -420,15 +367,16 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -465,15 +413,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -495,15 +444,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -548,7 +498,6 @@ def __init__(self, parent=None): parent: parent widget """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) @@ -559,37 +508,38 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -606,7 +556,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } From dfcf39f4025760751ef656fd1a045fcc31627937 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 298/577] Fix inference --- .../code_models/model_instance_seg.py | 5 +- napari_cellseg3d/code_models/model_workers.py | 8 +- .../code_plugins/plugin_model_inference.py | 11 +- napari_cellseg3d/config.py | 6 +- notebooks/assess_instance.ipynb | 479 +----------------- 5 files changed, 30 insertions(+), 479 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 253005b3..13acae56 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -501,8 +501,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -514,6 +516,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 06eff6e5..367be2f0 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -544,9 +544,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -599,8 +597,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 1da36989..ff173b43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,9 +555,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -725,9 +726,7 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3e92b88e..d1d8ab88 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -116,12 +116,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..40412282 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,500 +4,47 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "tags": [] + "collapsed": true }, "outputs": [], "source": [ - "import napari\n", "import numpy as np\n", - "from pathlib import Path\n", "from tifffile import imread\n", - "\n", - "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", - "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import (\n", - " binary_connected,\n", - " binary_watershed,\n", - " voronoi_otsu,\n", - " to_semantic,\n", - ")" + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [] - }, + "execution_count": null, "outputs": [], - "source": [ - "viewer = napari.Viewer()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", - "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", - "\n", - "prediction = imread(prediction_path)\n", - "gt_labels = imread(gt_labels_path)\n", - "\n", - "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", - "gt_labels_resized = resize(gt_labels, zoom)\n", - "\n", - "\n", - "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.5817600487210719" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from napari_cellseg3d.utils import dice_coeff\n", - "\n", - "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "source": [], "metadata": { "collapsed": false, - "jupyter": { - "outputs_hidden": false + "pycharm": { + "name": "#%%\n" } - }, - "outputs": [], - "source": [ - "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", - "\n", - "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n", - "125\n" - ] - } - ], - "source": [ - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)\n", - "print(np.unique(gt_labels_resized).shape[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized, thres_small=2)\n", - "viewer.add_labels(connected, name=\"connected\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(68,\n", - " 43,\n", - " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watershed = binary_watershed(\n", - " prediction_resized, thres_small=2, rem_seed_thres=1\n", - ")\n", - "viewer.add_labels(watershed)\n", - "eval.evaluate_model_performance(gt_labels_resized, watershed)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", - "\n", - "from skimage.morphology import remove_small_objects\n", - "\n", - "voronoi = remove_small_objects(voronoi, 2)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "dtype('int64')" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gt_labels_resized.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# np.unique(voronoi, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# np.unique(gt_labels_resized, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file From 150889a1dad4c6564d3346ea9102b82f83800d63 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 299/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 8 +- .../dev_scripts/artefact_labeling.py | 136 ++-- .../dev_scripts/correct_labels.py | 129 ++-- .../dev_scripts/evaluate_labels.py | 595 +++--------------- notebooks/assess_instance.ipynb | 401 +++++++++++- 5 files changed, 588 insertions(+), 681 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 13acae56..54c453d2 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -37,12 +37,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -118,14 +120,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method + """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) @@ -496,6 +499,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -547,10 +551,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 3f95e1a8..875ca9b6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,16 +1,16 @@ -import os - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path import scipy.ndimage as ndimage -from skimage.filters import threshold_otsu -from tifffile import imread, imwrite - -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - +import os +import napari # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu + """ New code by Yves Paychere Creates labels of artifacts in an image based on existing labels of neurons @@ -44,9 +44,7 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append( - np.array([i, unique[np.argmax(counts)]]) - ) + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -63,7 +61,7 @@ def map_labels(labels, artefacts): def make_labels( - image, + path_image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -75,8 +73,8 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - image : str - image array + path_image : str + Path to image. path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -95,26 +93,21 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(image) + image = imread(path_image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( - np.max(image_contrasted) - np.min(image_contrasted) - ) - + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size( - labels, min_size=threshold_size, is_labeled=True - ) + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -126,27 +119,26 @@ def make_labels( ) -def select_image_by_labels(image, labels, path_image_out, label_values): +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - image : np.array - image. - labels : np.array - labels. + path_image : str + Path to image. + path_labels : str + Path to labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - # image = imread(image) - # labels = imread(labels) - + image = imread(path_image) + labels = imread(path_labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smallest cube that contains all the non-zero pixels of a 3d image +# select the smalles cube that contains all the none zero pixel of an 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -164,15 +156,16 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(image, path_image_out): +def crop_image_path(path_image, path_image_out): """Crop image. Parameters ---------- - image : np.array - image + path_image : str + Path to image. path_image_out : str Path of the output image. """ + image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -220,9 +213,7 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile( - image[neurons], threshold_artefact_brightness_percent - ) + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -253,9 +244,7 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile( - sizes, threshold_artefact_size_percent - ) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -305,8 +294,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image, - labels, + image_path, + labels_path, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -315,10 +304,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image : np.array - image for artefact detection. - labels : np.array - label image array with each neurons labelled as a different int value. + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -328,6 +317,9 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ + image = imread(image_path) + labels = imread(labels_path) + artefacts = make_artefact_labels( image, labels, @@ -347,12 +339,11 @@ def visualize_images(paths): Parameters ---------- paths : list - List of images to visualize. + List of paths to images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - image = imread(path) - viewer.add_image(image) + viewer.add_image(imread(path), name=os.path.basename(path)) # wait for the user to close the viewer napari.run() @@ -379,12 +370,8 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [ - f for f in os.listdir(path + "/labels") if f.endswith(".tif") - ] - path_images = [ - f for f in os.listdir(path + "/volumes") if f.endswith(".tif") - ] + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] # sort the list path_labels.sort() path_images.sort() @@ -412,22 +399,23 @@ def create_artefact_labels_from_folder( ) -# if __name__ == "__main__": -# repo_path = Path(__file__).resolve().parents[1] -# print(f"REPO PATH : {repo_path}") -# paths = [ -# "dataset_clean/cropped_visual/train", -# "dataset_clean/cropped_visual/val", -# "dataset_clean/somatomotor", -# "dataset_clean/visual_tif", -# ] -# for data_path in paths: -# path = str(repo_path / data_path) -# print(path) -# create_artefact_labels_from_folder( -# path, -# do_visualize=False, -# threshold_artefact_brightness_percent=20, -# threshold_artefact_size_percent=1, -# contrast_power=20, -# ) +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 168990e1..f94327e2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,19 @@ -import threading -import time -import warnings -from functools import partial -from pathlib import Path - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings from napari.qt.threading import thread_worker -from tifffile import imread, imwrite from tqdm import tqdm - -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - +import threading # import sys # sys.path.append(str(Path(__file__) / "../../")) + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -35,9 +33,7 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm( - range(len(unique_label)), desc="relabeling", ncols=100 - ): + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): i = unique_label[i_label] if i == 0: continue @@ -85,16 +81,13 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact, test=False): +def ask_labels(unique_artefact): global returns returns = [] - if not test: - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] - else: - i_labels_to_add_tmp = [0] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] if i_labels_to_add_tmp == [0]: print("no label added") @@ -137,15 +130,7 @@ def ask_labels(unique_artefact, test=False): print("close the napari window to continue") -def relabel( - image_path, - label_path, - go_fast=False, - check_for_unicity=True, - delay=0.3, - viewer=None, - test=False, -): +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -159,8 +144,6 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 - viewer : napari.Viewer, optional - the napari viewer, by default None """ global returns @@ -175,10 +158,7 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - if not test: - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -198,49 +178,30 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread( - target=partial(ask_labels, test=test), args=(unique_artefact,) - ) + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add), 0, artefact - ) - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - viewer.add_image(image, name="image") + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - if not test: - napari.run() + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add_tmp), artefact, 0 - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) print("these labels will be added") - if test: - viewer.close() - viewer = napari.view_image(image) if viewer is None else viewer - if not test: - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") - if test: - revert = "n" - viewer.close() + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) - if test: - break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -297,16 +258,12 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget( - old_label, new_label, map_labels_existing, delay=0.5 -): +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect( - lambda arg: modify_viewer(old_label, new_label, arg) - ) + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -323,12 +280,8 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] - ) - new_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] - ) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -337,9 +290,7 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget( - old_label, new_label, map_labels_existing, delay=delay - ) + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) napari.run() @@ -356,14 +307,14 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, - str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") -# -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +if __name__ == "__main__": + + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index bd2f0768..857bcd19 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,20 +1,74 @@ -import napari import numpy as np import pandas as pd from tqdm import tqdm +import napari from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, -): + +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): """Evaluate the model performance. Parameters ---------- @@ -22,10 +76,8 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - print_details : bool + do_print : bool If True, print the results. - visualize : bool - If True, visualize the results. Returns ------- neuron_found : float @@ -49,7 +101,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels, threshold_correct + labels, model_labels ) # calculate the number of neurons individually found @@ -67,9 +119,7 @@ def evaluate_model_performance( artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean( - [i[3] for i in map_labels_existing] - ) + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -78,9 +128,7 @@ def evaluate_model_performance( if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean( - [i[2] for i in map_fused_neurons] - ) + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -95,37 +143,27 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") - log.info( - f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" - ) - log.info( - f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( + "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, ) - log.info( - f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" + print( + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, ) - log.info( - f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" + print( + "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, ) - log.info( - f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact ) - if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -141,21 +179,15 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, - unique_labels, - 0, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where( - np.isin(labels, neurones_not_found_labels), labels, 0 - ) + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), - model_labels, - 0, + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -183,81 +215,6 @@ def evaluate_model_performance( ) -def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > threshold_correct: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > threshold_correct: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels - - def save_as_csv(results, path): """ Save the results as a csv file @@ -269,7 +226,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - log.debug(np.array(results).shape) + print(np.array(results).shape) df = pd.DataFrame( [results], columns=[ @@ -287,380 +244,6 @@ def save_as_csv(results, path): df.to_csv(path, index=False) -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From 7b080b6eb436b024cc1de1e5cb51509021561ec8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 300/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From 6a94022a89114b28ba8468e987311f5246c2b9d5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 301/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- notebooks/assess_instance.ipynb | 553 ++++++++---------- 5 files changed, 564 insertions(+), 351 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 54c453d2..40c0cfc8 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -136,7 +136,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { From d9b632cb1d65e83849b17e1391ea9be9ddb13213 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 302/577] black --- .../code_models/model_instance_seg.py | 21 ++++++++---- napari_cellseg3d/code_models/model_workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 40c0cfc8..40684cb8 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -33,7 +33,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -56,7 +56,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -373,13 +380,13 @@ def fill(lst, n=len(properties) - 1): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -419,13 +426,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -456,7 +463,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 367be2f0..2d4ba51a 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -544,7 +544,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ff173b43..9fa1e9cf 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -557,7 +557,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -726,7 +728,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index d1d8ab88..1728c51c 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -116,11 +116,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From e6699f0936e753916a4474ae600248accd5394f9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 303/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From 596c2d6c9e9e03d1849662bd643e6d96ed6add72 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 304/577] Added pre-commit hooks --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 834a225e..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From bdbb2dcdcd6680a115d4e346b3af8a323f711855 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 305/577] Enfore pre-commit style --- .gitignore | 4 - .../_tests/test_plugin_inference.py | 2 - .../code_models/model_instance_seg.py | 8 +- .../code_plugins/plugin_model_inference.py | 3 - .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 3 - .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- notebooks/assess_instance.ipynb | 79 +++++++++++++------ 10 files changed, 76 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index df43b4fa..e86beea4 100644 --- a/.gitignore +++ b/.gitignore @@ -106,7 +106,3 @@ notebooks/full_plot.html *.png *.prof -#include test data -!napari_cellseg3d/_tests/res/test.tif -!napari_cellseg3d/_tests/res/test.png -!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 584ffd3b..e15958e6 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,8 +7,6 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST - - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 40684cb8..c4b10adf 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function from dataclasses import dataclass from typing import List import numpy as np @@ -10,6 +8,7 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -539,14 +538,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 9fa1e9cf..302a52c9 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,9 +9,6 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 5463a4ff..fdcad6d3 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,9 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 1728c51c..6df82043 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,10 +7,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, From 050d7264d07101b54c8f69d62247e4f2bf36fa9b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 306/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e86beea4..755de742 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,4 @@ notebooks/full_plot.html *.png *.prof + From 363cfac5ee4229b2a65d10caf372cc3a7ddab0f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 307/577] Version bump --- napari_cellseg3d/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..736c7f72 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1,2 @@ __version__ = "0.0.2rc6" + From 6f39e36b69bb90e1c4e2d24421b0419c9be5b035 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 308/577] Updated project files --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9c5adda7..c94c4ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ dev = [ "ruff", "tuna", "pre-commit", + ] docs = [ "sphinx", @@ -115,3 +116,4 @@ test = [ "tox", "twine", ] + From 14418430b8faa7faf89e540ae7f9890f97d44cba Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 309/577] Fixed missing parent error --- napari_cellseg3d/code_models/model_instance_seg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c4b10adf..5f265dfd 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -456,7 +456,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From 004d8499632f703baa7511829f8fd38b07556cee Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 310/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 5f265dfd..979f861c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -138,6 +138,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -415,8 +418,8 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -449,7 +452,7 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -509,7 +512,7 @@ def __init__(self, parent=None): """ super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -528,7 +531,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -538,8 +541,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -563,9 +568,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) From a323cb9e97a79665a5b11c980c1255a40e4381f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 311/577] Removing dask-image --- .gitignore | 1 + napari_cellseg3d/utils.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 755de742..f8547d92 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,4 @@ notebooks/full_plot.html *.prof +*.prof diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index a52c3de9..e09f12ba 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,7 +2,6 @@ import warnings from datetime import datetime from pathlib import Path - import numpy as np from skimage import io from skimage.filters import gaussian @@ -272,7 +271,6 @@ def annotation_to_input(label_ermito): anno = normalize_x(anno[np.newaxis, :, :, :]) return anno - # def check_csv(project_path, ext): # if not Path(Path(project_path) / Path(project_path).name).is_file(): # cols = [ From 4be3b274cefe72e0541708ebb0d373607198dabc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 312/577] Fixed erroneous dtype conversion --- .../code_models/model_instance_seg.py | 13 ++++++++-- .../code_plugins/plugin_convert.py | 25 +++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 979f861c..436135a1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -137,12 +137,12 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -489,6 +489,15 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index a04d9f09..3a60dff0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -165,19 +165,18 @@ def _start(self): f"isotropic_{layer.name}", ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): From 631ec360a5392ebcc057db23fed9042608584a8a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 313/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0abcf387..474e1d7f 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,8 +1,8 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, @@ -24,9 +24,4 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) - if utils_name == "Convert to instance labels": - # to avoid issues with Voronoi-Otsu missing runtime - menu = widget.utils_widgets[i].instance_widgets.method_choice - menu.setCurrentIndex(menu.currentIndex() + 1) - widget.utils_widgets[i]._start() From 6abbec8e51bc912af92aa5e60ec58587965a3696 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 314/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 87338cd8..46d84b40 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy -; pyopencl[pocl] + pocl +; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 4c6288f8876a4a3f8e9c9ee7535feb2311d8693d Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 315/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 46d84b40..6ba5efac 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From af9d2cacf5c3bd6c414d4c784ca9bc1dd7be820a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 316/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6ba5efac..ee946a73 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From c5bbce59a55cb3cb3161d37158fc55f3052fcd2d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 317/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 474e1d7f..253f51bc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -24,4 +24,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index ee946a73..40a2a7a0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 0d36eb7e571c8e1c562cdf926411ebc8c6cdadb1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 318/577] Relabeling tests --- .gitignore | 6 +- .../_tests/test_labels_correction.py | 3 +- .../dev_scripts/artefact_labeling.py | 93 +++++++++---------- .../dev_scripts/correct_labels.py | 75 ++++++++++----- 4 files changed, 101 insertions(+), 76 deletions(-) diff --git a/.gitignore b/.gitignore index f8547d92..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -106,5 +106,7 @@ notebooks/full_plot.html *.png *.prof - -*.prof +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..9d4e7801 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 9a344545..bf724a46 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,7 +1,5 @@ import numpy as np -from tifffile import imread -from tifffile import imwrite -from pathlib import Path +from tifffile import imwrite, imread import scipy.ndimage as ndimage import os import napari @@ -64,7 +62,7 @@ def map_labels(labels, artefacts): def make_labels( - path_image, + image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -76,7 +74,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - path_image : str + image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -96,7 +94,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(path_image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -126,28 +124,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - path_image : str - Path to image. - path_labels : str - Path to labels. + image : np.array + image. + labels : np.array + labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - image = imread(path_image) - labels = imread(path_labels) + # image = imread(image) + # labels = imread(labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smalles cube that contains all the none zero pixel of an 3d image +# select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -165,16 +161,15 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(path_image, path_image_out): +def crop_image_path(image, path_image_out): """Crop image. Parameters ---------- - path_image : str - Path to image. + image : np.array + image path_image_out : str Path of the output image. """ - image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -307,8 +302,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image_path, - labels_path, + image, + labels, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -317,10 +312,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image_path : str - Path to image file. - labels_path : str - Path to label image file with each neurons labelled as a different value. + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -330,9 +325,6 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ - image = imread(image_path) - labels = imread(labels_path) - artefacts = make_artefact_labels( image, labels, @@ -352,11 +344,12 @@ def visualize_images(paths): Parameters ---------- paths : list - List of paths to images to visualize. + List of images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - viewer.add_image(imread(path), name=os.path.basename(path)) + image = imread(path) + viewer.add_image(image) # wait for the user to close the viewer napari.run() @@ -416,22 +409,22 @@ def create_artefact_labels_from_folder( ) -if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] - print(f"REPO PATH : {repo_path}") - paths = [ - "dataset_clean/cropped_visual/train", - "dataset_clean/cropped_visual/val", - "dataset_clean/somatomotor", - "dataset_clean/visual_tif", - ] - for data_path in paths: - path = str(repo_path / data_path) - print(path) - create_artefact_labels_from_folder( - path, - do_visualize=False, - threshold_artefact_brightness_percent=20, - threshold_artefact_size_percent=1, - contrast_power=20, - ) +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index cd09754e..50f2e47a 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -4,6 +4,7 @@ import scipy.ndimage as ndimage import napari from pathlib import Path +from functools import partial import time import warnings from napari.qt.threading import thread_worker @@ -85,13 +86,16 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact): +def ask_labels(unique_artefact, test=False): global returns returns = [] - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] if i_labels_to_add_tmp == [0]: print("no label added") @@ -135,7 +139,13 @@ def ask_labels(unique_artefact): def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -150,6 +160,8 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None """ global returns @@ -164,9 +176,10 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -186,15 +199,22 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) t.start() artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - viewer = napari.view_image(image) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - napari.run() + if not test: + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones @@ -205,15 +225,26 @@ def relabel( np.isin(artefact, i_labels_to_add_tmp), artefact, 0 ) print("these labels will be added") - viewer = napari.view_image(image) - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") + if test: + viewer.close() + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) + if test: + break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -334,9 +365,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") - image_path = str(im_path / "image.tif") - gt_labels_path = str(im_path / "labels.tif") - - relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) From 02fc0d1f27b1d03b4aa542ba94c0403128033f2e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 319/577] Run full suite of pre-commit hooks --- README.md | 2 +- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 3 +-- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- 9 files changed, 29 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ca8d0931..ece6c6f4 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 253f51bc..ea12024a 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 436135a1..6d0dc13d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -14,8 +14,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -489,7 +489,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log From 7c32086b026dae73b7e28a44875ffbb60101b201 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:08:38 +0200 Subject: [PATCH 320/577] Enforce style --- napari_cellseg3d/__init__.py | 1 - napari_cellseg3d/_tests/test_plugin_inference.py | 1 + napari_cellseg3d/_tests/test_plugin_utils.py | 5 +---- napari_cellseg3d/code_models/model_instance_seg.py | 8 +++++--- napari_cellseg3d/code_models/models/unet/model.py | 1 + napari_cellseg3d/code_plugins/plugin_convert.py | 2 ++ napari_cellseg3d/code_plugins/plugin_review.py | 1 + napari_cellseg3d/code_plugins/plugin_utilities.py | 12 +++++------- napari_cellseg3d/config.py | 1 - napari_cellseg3d/interface.py | 3 +-- pyproject.toml | 1 - 11 files changed, 17 insertions(+), 19 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 736c7f72..11e8de0e 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1,2 +1 @@ __version__ = "0.0.2rc6" - diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e15958e6..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,6 +7,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index ea12024a..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -4,10 +4,7 @@ from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, - Utilities, -) +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6d0dc13d..cc362eac 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import List + import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget @@ -9,14 +10,15 @@ from skimage.segmentation import watershed from tifffile import imread -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index 9591d054..ee566be7 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -5,6 +5,7 @@ create_decoders, create_encoders, ) +from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 3a60dff0..47351bbb 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,5 +1,6 @@ import warnings from pathlib import Path + import napari import numpy as np from qtpy.QtWidgets import QSizePolicy @@ -355,6 +356,7 @@ def _start(self): self.images_filepaths, ) + class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 7ed6c549..80855f4a 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -17,6 +17,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index fdcad6d3..0de51392 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -8,13 +8,11 @@ # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import ( - AnisoUtils, - RemoveSmallUtils, - ThresholdUtils, - ToInstanceUtils, - ToSemanticUtils, -) +from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils +from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils +from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6df82043..737b53aa 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -9,7 +9,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d697245a..2ed0434b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,8 +7,7 @@ # Qt # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt +from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt from qtpy.QtCore import QUrl diff --git a/pyproject.toml b/pyproject.toml index c94c4ee4..bfa14ac5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,4 +116,3 @@ test = [ "tox", "twine", ] - From 736254d5b659b7894fc8b6e39843a8e42f0a876e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 321/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 57 +++++++++++++++---- .../code_plugins/plugin_convert.py | 1 + napari_cellseg3d/config.py | 9 ++- napari_cellseg3d/interface.py | 3 +- 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..584ffd3b 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,6 +8,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cc362eac..2c308c5d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -8,12 +8,18 @@ from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed + +from skimage.filters import thresholding +from skimage.transform import resize +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -82,6 +88,42 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -122,7 +164,6 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -390,7 +431,7 @@ def __init__(self, widget_parent=None): function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, + # widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -436,7 +477,7 @@ def __init__(self, widget_parent=None): function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, + # widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -475,8 +516,8 @@ def __init__(self, widget_parent=None): ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -504,6 +545,7 @@ def run_method(self, image): self.counters[0].value(), self.counters[1].value(), # self.counters[2].value(), + ) @@ -518,7 +560,6 @@ def __init__(self, parent=None): Args: parent: parent widget - """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -528,14 +569,12 @@ def __init__(self, parent=None): """Contains the instance of the method, with its name as key""" self.instance_widgets = {} """Contains the lists of widgets for each methods, to show/hide""" - self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) @@ -555,11 +594,11 @@ def _build(self): logger.debug( f"Caught runtime error {e}, most likely during testing" ) - self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: @@ -571,12 +610,10 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters - Args: volume: image data to run method on Returns: processed image from self._method - """ method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 47351bbb..33192fa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -13,6 +13,7 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 737b53aa..9ea836f2 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,7 +8,6 @@ import numpy as np from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -16,6 +15,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -123,7 +128,7 @@ class InstanceSegConfig: class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 2ed0434b..0a094d6b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1048,12 +1048,13 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): + def _update_step(self): #FIXME check divide_factor if self.value() < 0.9: self.setSingleStep(0.01) else: self.setSingleStep(0.1) + @property def tooltips(self): return self.toolTip() From 7767110aab219a8f6a78cdbdc1445ca373549225 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 322/577] Disabled small removal in Voronoi-Otsu --- .../code_models/model_instance_seg.py | 40 +------------------ 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2c308c5d..19b5f5ba 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -87,43 +87,6 @@ def __init__( def run_method(self, image): raise NotImplementedError("Must be defined in child classes") - -class InstanceMethod: - def __init__( - self, - name: str, - function: callable, - num_sliders: int, - num_counters: int, - ): - self.name = name - self.function = function - self.counters: List[ui.DoubleIncrementCounter] = [] - self.sliders: List[ui.Slider] = [] - if num_sliders > 0: - for i in range(num_sliders): - widget = f"slider_{i}" - setattr( - self, - widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), - ) - self.sliders.append(getattr(self, widget)) - - if num_counters > 0: - for i in range(num_counters): - widget = f"counter_{i}" - setattr( - self, - widget, - ui.DoubleIncrementCounter(label=""), - ) - self.counters.append(getattr(self, widget)) - - def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") - - @dataclass class ImageStats: volume: List[float] @@ -431,7 +394,7 @@ def __init__(self, widget_parent=None): function=binary_watershed, num_sliders=2, num_counters=2, - # widget_parent=widget_parent, + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -545,7 +508,6 @@ def run_method(self, image): self.counters[0].value(), self.counters[1].value(), # self.counters[2].value(), - ) From d7d6f42c14b48478270c7e7a9f627a515aa41f02 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 323/577] Added new docs for instance seg --- napari_cellseg3d/code_models/model_instance_seg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 19b5f5ba..d43625d8 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -51,7 +51,6 @@ def __init__( num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets - """ self.name = name self.function = function @@ -384,11 +383,11 @@ def fill(lst, n=len(properties) - 1): fill([len(properties)]), ) - class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" def __init__(self, widget_parent=None): + super().__init__( name=WATERSHED, function=binary_watershed, @@ -430,11 +429,11 @@ def run_method(self, image): self.counters[1].value(), ) - class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" def __init__(self, widget_parent=None): + super().__init__( name=CONNECTED_COMP, function=binary_connected, @@ -466,6 +465,7 @@ class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" def __init__(self, widget_parent=None): + super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -480,7 +480,6 @@ def __init__(self, widget_parent=None): self.counters[0].setMaximum(100) self.counters[0].setValue(2) self.counters[1].label.setText("Outline sigma") # smoothness - self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" @@ -585,4 +584,5 @@ def run_method(self, volume): VORONOI_OTSU: VoronoiOtsu, WATERSHED: Watershed, CONNECTED_COMP: ConnectedComponents, + } From e0e7a0fecd8981ff58d55184b13edcb68d0df81c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 324/577] isort --- napari_cellseg3d/_tests/test_plugin_inference.py | 1 - napari_cellseg3d/code_models/model_instance_seg.py | 3 ++- napari_cellseg3d/code_plugins/plugin_convert.py | 1 - napari_cellseg3d/config.py | 7 +------ 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 584ffd3b..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,7 +8,6 @@ from napari_cellseg3d.config import MODEL_LIST - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d43625d8..6734b06f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,11 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget + from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - from skimage.filters import thresholding from skimage.transform import resize # from skimage.measure import mesh_surface_area @@ -20,6 +20,7 @@ from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 33192fa4..47351bbb 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -13,7 +13,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 9ea836f2..a43ff700 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,6 +7,7 @@ import napari import numpy as np + from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -15,12 +16,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER From 09a6e3ba29f090e3adf676e0fe57cd6fc38a44fc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 325/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 45c3be1c..3becfaca 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,3 @@ [pytest] qt_api=pyqt5 + From 2f03481c2951e11aa0aadbda4dfa1807074f1dc7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 326/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- napari_cellseg3d/code_models/model_instance_seg.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 6734b06f..e1b2eb03 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -17,10 +17,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -71,6 +69,7 @@ def __init__( text_label="", parent=None, ), + ) self.sliders.append(getattr(self, widget)) @@ -384,11 +383,11 @@ def fill(lst, n=len(properties) - 1): fill([len(properties)]), ) + class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" def __init__(self, widget_parent=None): - super().__init__( name=WATERSHED, function=binary_watershed, @@ -434,13 +433,12 @@ class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" def __init__(self, widget_parent=None): - super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - # widget_parent=widget_parent, + widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -466,7 +464,6 @@ class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" def __init__(self, widget_parent=None): - super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -585,5 +582,4 @@ def run_method(self, volume): VORONOI_OTSU: VoronoiOtsu, WATERSHED: Watershed, CONNECTED_COMP: ConnectedComponents, - } From bb265100c741483ef34f12f303debd1e65e9f361 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 327/577] Fix inference --- napari_cellseg3d/code_models/model_instance_seg.py | 1 + napari_cellseg3d/code_plugins/plugin_model_inference.py | 1 + napari_cellseg3d/config.py | 6 +++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e1b2eb03..19d87a6a 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -528,6 +528,7 @@ def __init__(self, parent=None): """Contains the instance of the method, with its name as key""" self.instance_widgets = {} """Contains the lists of widgets for each methods, to show/hide""" + self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 302a52c9..69b6244e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -557,6 +557,7 @@ def start(self): method=self.instance_widgets.methods[ self.instance_widgets.method_choice.currentText() ], + ) self.post_process_config = config.PostProcessConfig( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index a43ff700..9a14d706 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -112,6 +112,10 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class InstanceSegConfig: @@ -123,7 +127,7 @@ class InstanceSegConfig: class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ From f9e47d393a195052137b8d4fc97e3ef29b6c2ee4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 328/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- napari_cellseg3d/code_models/model_instance_seg.py | 4 ++++ napari_cellseg3d/dev_scripts/artefact_labeling.py | 11 ++++++----- napari_cellseg3d/dev_scripts/correct_labels.py | 4 +--- napari_cellseg3d/dev_scripts/evaluate_labels.py | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 19d87a6a..e279235e 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -50,6 +50,7 @@ def __init__( num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -519,6 +520,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) self.method_choice = ui.DropdownMenu( @@ -570,10 +572,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = self.methods[self.method_choice.currentText()] return method.run_method(volume) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..aaf345cf 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -12,7 +12,6 @@ # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) - """ New code by Yves Paychere Creates labels of artifacts in an image based on existing labels of neurons @@ -78,7 +77,7 @@ def make_labels( Parameters ---------- image : str - Path to image. + image array path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -97,7 +96,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - # image = imread(image) + image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -107,6 +106,7 @@ def make_labels( image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( np.max(image_contrasted) - np.min(image_contrasted) ) + image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) @@ -126,7 +126,6 @@ def make_labels( image_contrasted.astype(np.float32), ) - def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters @@ -142,10 +141,12 @@ def select_image_by_labels(image, labels, path_image_out, label_values): """ # image = imread(image) # labels = imread(labels) + image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) + # select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) @@ -430,4 +431,4 @@ def create_artefact_labels_from_folder( # threshold_artefact_brightness_percent=20, # threshold_artefact_size_percent=1, # contrast_power=20, -# ) +# ) \ No newline at end of file diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..4c52675c 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -18,7 +18,6 @@ # import sys # sys.path.append(str(Path(__file__) / "../../")) - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -87,7 +86,6 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] - def ask_labels(unique_artefact, test=False): global returns returns = [] @@ -139,7 +137,6 @@ def ask_labels(unique_artefact, test=False): returns = [i_labels_to_add_tmp] print("close the napari window to continue") - def relabel( image_path, label_path, @@ -373,3 +370,4 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # gt_labels_path = str(im_path / "labels.tif") # # relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) + diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index ee9919b6..087a01bb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -503,4 +503,4 @@ def save_as_csv(results, path): # "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" # ) # -# evaluate_model_performance(labels, labels_model, visualize=True) +# evaluate_model_performance(labels, labels_model, visualize=True) \ No newline at end of file From ec2cdd05a773a7fe74c5afe23afa97e084841326 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 329/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- napari_cellseg3d/dev_scripts/evaluate_labels.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 087a01bb..e253eb2c 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -7,7 +7,6 @@ PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct - def evaluate_model_performance( labels, model_labels, @@ -47,7 +46,7 @@ def evaluate_model_performance( mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels, threshold_correct ) @@ -57,7 +56,7 @@ def evaluate_model_performance( # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) From 10444f8ac745673180033c1bb68c0c460c14708f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 330/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - napari_cellseg3d/dev_scripts/correct_labels.py | 1 - napari_cellseg3d/dev_scripts/evaluate_labels.py | 6 ++++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index aaf345cf..69d6535d 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4c52675c..9fcb2a88 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -17,7 +17,6 @@ # import sys # sys.path.append(str(Path(__file__) / "../../")) - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index e253eb2c..b74251f8 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,5 +1,7 @@ import napari import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -46,7 +48,7 @@ def evaluate_model_performance( mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels, threshold_correct ) @@ -56,7 +58,7 @@ def evaluate_model_performance( # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) From 55f6fdc8d6f7749d00f9d34632158e940adb36d3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 331/577] black --- napari_cellseg3d/code_models/model_instance_seg.py | 1 - napari_cellseg3d/code_plugins/plugin_model_inference.py | 1 - napari_cellseg3d/config.py | 2 ++ 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e279235e..b1d4d9b7 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -70,7 +70,6 @@ def __init__( text_label="", parent=None, ), - ) self.sliders.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 69b6244e..302a52c9 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -557,7 +557,6 @@ def start(self): method=self.instance_widgets.methods[ self.instance_widgets.method_choice.currentText() ], - ) self.post_process_config = config.PostProcessConfig( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 9a14d706..f50b152c 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -112,11 +112,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class InstanceSegConfig: enabled: bool = False From 43ebdcd1cd60fdc22c3c693bcb6184727110cd98 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 332/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 188 +++++++++++++++++- 1 file changed, 186 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b74251f8..6065520a 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,5 @@ import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -287,6 +285,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons ####################### # Slower version that was used for debugging From df4f9daf24b078971129f7dade7e1490ffa64d21 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 333/577] Enfore pre-commit style --- napari_cellseg3d/config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index f50b152c..ccfad955 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -119,11 +119,6 @@ class InstanceSegConfig: method: InstanceMethod = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - @dataclass class PostProcessConfig: From 6f1906e94f5f8308bd419b2a930dfc990af5e1b6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 334/577] Removing dask-image --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..ea3db7fc 100644 --- a/.gitignore +++ b/.gitignore @@ -110,3 +110,4 @@ notebooks/full_plot.html !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif + From dd58a06bf2bf4de3e0758373a3b63b9c67b2fc04 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 335/577] Fixed erroneous dtype conversion --- .../code_models/model_instance_seg.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index b1d4d9b7..412c87d7 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,13 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget - from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.filters import thresholding -from skimage.transform import resize + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread @@ -20,10 +18,6 @@ from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - - # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : @@ -86,6 +80,7 @@ def __init__( def run_method(self, image): raise NotImplementedError("Must be defined in child classes") + @dataclass class ImageStats: volume: List[float] @@ -126,6 +121,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -165,6 +161,8 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 + scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) + """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -429,6 +427,7 @@ def run_method(self, image): self.counters[1].value(), ) + class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" @@ -477,6 +476,7 @@ def __init__(self, widget_parent=None): ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 @@ -492,6 +492,7 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + ################ # For debugging # import napari @@ -536,6 +537,7 @@ def __init__(self, parent=None): def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) + try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) @@ -555,11 +557,11 @@ def _build(self): logger.debug( f"Caught runtime error {e}, most likely during testing" ) + self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: From 01d2473e37e8f943ed0a7dbfcbbc84682118c4e9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 336/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index cbfd97b2..7403f2b7 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,5 +1,4 @@ from pathlib import Path - import numpy as np from tifffile import imread From a1468ff8870e84b12a0fb28675625ac07a4becd5 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 337/577] Update tox.ini --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index 40a2a7a0..87338cd8 100644 --- a/tox.ini +++ b/tox.ini @@ -37,6 +37,5 @@ deps = pytest-qt qtpy ; pyopencl[pocl] -; opencv-python commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 9d97ac2ca8df60061a55bc2ebc884a695c7050f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 338/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 1 - pyproject.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f9fe2853..e4bff318 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: -# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - id: check-yaml diff --git a/pyproject.toml b/pyproject.toml index bfa14ac5..9c5adda7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,6 @@ dev = [ "ruff", "tuna", "pre-commit", - ] docs = [ "sphinx", From 6f4e4b69be25d3bd484a337d9e73c1c14aff2e5a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 339/577] Run full suite of pre-commit hooks --- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/code_models/model_instance_seg.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 412c87d7..c72bafe9 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,8 +15,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -492,7 +492,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari From 90cfbd51be07cbbfb198bf114b9f2626b452bf6f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 10:53:03 +0200 Subject: [PATCH 340/577] Enforce style --- .gitignore | 1 - napari_cellseg3d/_tests/pytest.ini | 1 - napari_cellseg3d/_tests/test_plugin_utils.py | 1 + napari_cellseg3d/config.py | 3 +-- napari_cellseg3d/dev_scripts/artefact_labeling.py | 5 +++-- napari_cellseg3d/dev_scripts/correct_labels.py | 3 ++- napari_cellseg3d/dev_scripts/evaluate_labels.py | 4 +++- napari_cellseg3d/interface.py | 3 +-- napari_cellseg3d/utils.py | 2 ++ 9 files changed, 13 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index ea3db7fc..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,3 @@ notebooks/full_plot.html !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif - diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 3becfaca..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,3 +1,2 @@ [pytest] qt_api=pyqt5 - diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 7403f2b7..cbfd97b2 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,4 +1,5 @@ from pathlib import Path + import numpy as np from tifffile import imread diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index ccfad955..737b53aa 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,8 +7,8 @@ import napari import numpy as np - from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -119,7 +119,6 @@ class InstanceSegConfig: method: InstanceMethod = None - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 69d6535d..90048a60 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage @@ -125,6 +126,7 @@ def make_labels( image_contrasted.astype(np.float32), ) + def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters @@ -145,7 +147,6 @@ def select_image_by_labels(image, labels, path_image_out, label_values): imwrite(path_image_out, image.astype(np.float32)) - # select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) @@ -430,4 +431,4 @@ def create_artefact_labels_from_folder( # threshold_artefact_brightness_percent=20, # threshold_artefact_size_percent=1, # contrast_power=20, -# ) \ No newline at end of file +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9fcb2a88..aacf08f8 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -85,6 +85,7 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] + def ask_labels(unique_artefact, test=False): global returns returns = [] @@ -136,6 +137,7 @@ def ask_labels(unique_artefact, test=False): returns = [i_labels_to_add_tmp] print("close the napari window to continue") + def relabel( image_path, label_path, @@ -369,4 +371,3 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # gt_labels_path = str(im_path / "labels.tif") # # relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) - diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 6065520a..bd2f0768 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -7,6 +7,7 @@ PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + def evaluate_model_performance( labels, model_labels, @@ -285,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### @@ -688,4 +690,4 @@ def save_as_csv(results, path): # "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/" # ) # -# evaluate_model_performance(labels, labels_model, visualize=True) \ No newline at end of file +# evaluate_model_performance(labels, labels_model, visualize=True) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 0a094d6b..99f3b751 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1048,13 +1048,12 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): #FIXME check divide_factor + def _update_step(self): # FIXME check divide_factor if self.value() < 0.9: self.setSingleStep(0.01) else: self.setSingleStep(0.1) - @property def tooltips(self): return self.toolTip() diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index e09f12ba..a52c3de9 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,6 +2,7 @@ import warnings from datetime import datetime from pathlib import Path + import numpy as np from skimage import io from skimage.filters import gaussian @@ -271,6 +272,7 @@ def annotation_to_input(label_ermito): anno = normalize_x(anno[np.newaxis, :, :, :]) return anno + # def check_csv(project_path, ext): # if not Path(Path(project_path) / Path(project_path).name).is_file(): # cols = [ From 817e3cbb69aa34fa9bf48dc1c9981be60cfd301b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 14:29:40 +0200 Subject: [PATCH 341/577] Documentation update, crop contrast fix --- docs/res/welcome.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 12a20630..892549a8 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -103,8 +103,6 @@ This plugin mainly uses the following libraries and software: * `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase -* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase - * A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ @@ -115,7 +113,7 @@ This plugin mainly uses the following libraries and software: .. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype - +.. _WNet model: https://arxiv.org/abs/1711.08506 .. rubric:: References From e72df6914ad21063cbb903da7c9239daa1b2a445 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 11:09:12 +0200 Subject: [PATCH 342/577] Updated hooks --- .pre-commit-config.yaml | 1 + napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- .../code_models/model_instance_seg.py | 8 +-- napari_cellseg3d/code_models/model_workers.py | 7 ++- .../code_models/models/unet/model.py | 1 - .../code_plugins/plugin_convert.py | 63 ++++++++++--------- .../code_plugins/plugin_model_inference.py | 12 ++-- .../code_plugins/plugin_review.py | 1 - .../code_plugins/plugin_utilities.py | 16 ++--- .../dev_scripts/artefact_labeling.py | 3 +- .../dev_scripts/correct_labels.py | 8 +-- napari_cellseg3d/interface.py | 50 +++++++-------- 12 files changed, 88 insertions(+), 88 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4bff318..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - id: check-yaml diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index cbfd97b2..0abcf387 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c72bafe9..60f8bbda 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,8 +4,7 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -14,9 +13,8 @@ from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -561,7 +559,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 2d4ba51a..30d37bbd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -51,9 +51,10 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index ee566be7..9591d054 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -5,7 +5,6 @@ create_decoders, create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 47351bbb..6c8370c1 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,15 +4,16 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -166,18 +167,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -343,18 +345,19 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 302a52c9..22867343 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,10 +9,14 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 80855f4a..7ed6c549 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -17,7 +17,6 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 0de51392..5463a4ff 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,17 +2,17 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils -from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils +from napari_cellseg3d.code_plugins.plugin_convert import ( + AnisoUtils, + RemoveSmallUtils, + ThresholdUtils, + ToInstanceUtils, + ToSemanticUtils, +) from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 90048a60..3f95e1a8 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index aacf08f8..168990e1 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels @@ -227,10 +226,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 99f3b751..276f9214 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -8,32 +8,30 @@ # Qt # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils From 142ceee1f2a94dd3eb5083d4f80a8d4c486c211a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:11:22 +0100 Subject: [PATCH 343/577] Update setup.cfg --- setup.cfg | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/setup.cfg b/setup.cfg index 3a0bdaae..ede7724d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,30 @@ python_requires = >=3.8 package_dir = =. +# add your package requirements here +# the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +install_requires = + numpy + napari[all]>=0.4.14 + QtPy + opencv-python>=4.5.5 + dask-image>=0.6.0 + scikit-image>=0.19.2 + matplotlib>=3.4.1 + tifffile>=2022.2.9 + imageio-ffmpeg>=0.4.5 + torch>=1.11 + monai[nibabel,einops]>=0.9.0 + itk + tqdm + nibabel + pyclesperanto-prototype + scikit-image + pillow + tqdm + matplotlib + vispy>=0.9.6 + [options.packages.find] where = . From 10699f622e1db6fec7ff62bb29370e169a08789b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 344/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../dev_scripts/artefact_labeling.py | 12 +- .../dev_scripts/correct_labels.py | 28 +- .../dev_scripts/evaluate_labels.py | 303 ++++++++++++++++-- notebooks/assess_instance.ipynb | 281 +++++++++++++--- setup.cfg | 2 +- 5 files changed, 524 insertions(+), 102 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 3f95e1a8..102a7d35 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,10 +1,9 @@ import os import napari -import numpy as np -import scipy.ndimage as ndimage -from skimage.filters import threshold_otsu -from tifffile import imread, imwrite + +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed @@ -105,7 +104,6 @@ def make_labels( image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( np.max(image_contrasted) - np.min(image_contrasted) ) - image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) @@ -126,7 +124,9 @@ def make_labels( ) -def select_image_by_labels(image, labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 168990e1..c888378c 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -10,12 +10,13 @@ from napari.qt.threading import thread_worker from tifffile import imread, imwrite from tqdm import tqdm - -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import threading # import sys # sys.path.append(str(Path(__file__) / "../../")) + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -138,13 +139,7 @@ def ask_labels(unique_artefact, test=False): def relabel( - image_path, - label_path, - go_fast=False, - check_for_unicity=True, - delay=0.3, - viewer=None, - test=False, + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -175,10 +170,9 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - if not test: - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -205,11 +199,7 @@ def relabel( artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - viewer.add_image(image, name="image") + viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") if not test: diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index bd2f0768..f75ed6bd 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,19 +1,269 @@ import napari import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict +import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, + labels, model_labels, do_print=False, visualize=False ): """Evaluate the model performance. Parameters @@ -95,36 +345,35 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") + if do_print: + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) log.info( - f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + "Mean true positive ratio of the model: ", ) + log.info(mean_true_positive_ratio_model) log.info( - f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" + "Mean ratio of the neurons pixels correctly labelled: ", ) + log.info(mean_ratio_pixel_found) log.info( - f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", ) + log.info(mean_ratio_pixel_found_fused) log.info( - f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" + "Mean true positive ratio of the model for fused neurons: ", ) + log.info(mean_true_positive_ratio_model_fused) log.info( - f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -141,7 +390,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0, ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..fa22c7b7 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -19,7 +19,6 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", - " to_semantic,\n", ")" ] }, @@ -45,14 +44,12 @@ }, "outputs": [ { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] } ], "source": [ @@ -69,7 +66,9 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")" + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" ] }, { @@ -198,24 +197,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -250,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -303,14 +299,138 @@ " 1.0)" ] }, - "execution_count": 10, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected, name=\"connected\")\n", + "connected.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=2, rem_seed_thres=1\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -318,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -332,24 +452,24 @@ "(25, 64, 64)" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 2)\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -363,7 +483,7 @@ "dtype('int64')" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -374,35 +494,101 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels_resized, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { @@ -414,24 +600,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-22 14:47:30,755 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" ] }, { @@ -444,18 +627,18 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" ] }, - "execution_count": 15, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -466,7 +649,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/setup.cfg b/setup.cfg index ede7724d..78cc98ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=0.9.0 + monai[nibabel,einops]>=1.0.1 itk tqdm nibabel From e716653c8f9fcb74d65ebdf35a5bf8d164275ece Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 345/577] Enfore pre-commit style --- .gitignore | 4 - .../_tests/test_plugin_inference.py | 1 - .../code_models/model_instance_seg.py | 9 +- .../code_plugins/plugin_model_inference.py | 12 +- .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 1 + .../dev_scripts/artefact_labeling.py | 38 +- .../dev_scripts/correct_labels.py | 12 +- .../dev_scripts/evaluate_labels.py | 471 +----------------- notebooks/assess_instance.ipynb | 50 +- 10 files changed, 88 insertions(+), 514 deletions(-) diff --git a/.gitignore b/.gitignore index df43b4fa..e86beea4 100644 --- a/.gitignore +++ b/.gitignore @@ -106,7 +106,3 @@ notebooks/full_plot.html *.png *.prof -#include test data -!napari_cellseg3d/_tests/res/test.tif -!napari_cellseg3d/_tests/res/test.png -!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..e15958e6 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,7 +7,6 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 60f8bbda..40a07893 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -7,6 +7,7 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed +from tifffile import imread # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -550,16 +551,14 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug( - f"Caught runtime error {e}, most likely during testing" - ) + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets: + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 22867343..302a52c9 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,14 +9,10 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceMethod, - InstanceWidgets, -) -from napari_cellseg3d.code_models.model_workers import ( - InferenceResult, - InferenceWorker, -) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 5463a4ff..fdcad6d3 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,7 +2,9 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 737b53aa..6df82043 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -9,6 +9,7 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 102a7d35..04e288d8 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -412,22 +412,22 @@ def create_artefact_labels_from_folder( ) -# if __name__ == "__main__": -# repo_path = Path(__file__).resolve().parents[1] -# print(f"REPO PATH : {repo_path}") -# paths = [ -# "dataset_clean/cropped_visual/train", -# "dataset_clean/cropped_visual/val", -# "dataset_clean/somatomotor", -# "dataset_clean/visual_tif", -# ] -# for data_path in paths: -# path = str(repo_path / data_path) -# print(path) -# create_artefact_labels_from_folder( -# path, -# do_visualize=False, -# threshold_artefact_brightness_percent=20, -# threshold_artefact_size_percent=1, -# contrast_power=20, -# ) +if __name__ == "__main__": + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index c888378c..77835007 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -351,9 +351,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") -# -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index f75ed6bd..3eb62764 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -9,261 +9,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -345,15 +99,21 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") log.info( "Mean true positive ratio of the model: ", ) @@ -390,7 +150,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -723,193 +483,6 @@ def save_as_csv(results, path): # # return label_report_list, new_labels, map_fused_neurons -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index fa22c7b7..b2382c31 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,12 +44,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -197,21 +199,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -600,21 +605,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -627,15 +635,15 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 13, From e433cf326591776aec326d6d7404dd40e6cb7c57 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 346/577] Updated project files --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9c5adda7..c94c4ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ dev = [ "ruff", "tuna", "pre-commit", + ] docs = [ "sphinx", @@ -115,3 +116,4 @@ test = [ "tox", "twine", ] + From adc337abf2977834123070eac5ecc45299608da3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 347/577] Removing dask-image --- .gitignore | 2 ++ napari_cellseg3d/dev_scripts/convert.py | 3 ++- napari_cellseg3d/utils.py | 13 ++++++------- notebooks/full_plot.ipynb | 1 - setup.cfg | 1 + 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index e86beea4..f8547d92 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,5 @@ notebooks/full_plot.html *.png *.prof + +*.prof diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 641de627..479a07dd 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,7 +2,8 @@ import os import numpy as np -from tifffile import imread, imwrite +from tifffile import imread +from tifffile import imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index a52c3de9..ecb6a199 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -4,6 +4,8 @@ from pathlib import Path import numpy as np + +# from dask import delayed from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -455,13 +457,10 @@ def load_images( raise ValueError("If loading as a folder, filetype must be specified") if as_folder: - try: - images_original = tfl_imread(filename_pattern_original) - except ValueError: - LOGGER.error( - "Loading a stack this way is no longer supported. Use napari to load a stack." - ) - + raise NotImplementedError( + "Loading as folder not implemented yet. Use napari to load as folder" + ) + # images_original = dask_imread(filename_pattern_original) else: images_original = tfl_imread( filename_pattern_original diff --git a/notebooks/full_plot.ipynb b/notebooks/full_plot.ipynb index 5c640e1b..87f973f9 100644 --- a/notebooks/full_plot.ipynb +++ b/notebooks/full_plot.ipynb @@ -10,7 +10,6 @@ "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", - "from PIL import Image\n", "from tifffile import imread" ] }, diff --git a/setup.cfg b/setup.cfg index 78cc98ce..6111ed7e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,7 @@ package_dir = # add your package requirements here # the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use +# FIXME remove dask install_requires = numpy napari[all]>=0.4.14 From 9342597fcf74bb666f485d3972d36b3ee67dfe6d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 348/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f9fe2853..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - id: check-yaml - - id: check-added-large-files - - id: check-toml -# - repo: https://github.com/pycqa/isort -# rev: 5.12.0 -# hooks: -# - id: isort -# args: ["--profile", "black", --line-length=79] + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From cb9e3068f4b593243eaefaa19cdd170f2253d371 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 349/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .../_tests/test_plugin_inference.py | 1 + .../code_models/model_instance_seg.py | 187 ++++++++++-------- napari_cellseg3d/code_models/model_workers.py | 21 +- .../code_plugins/plugin_convert.py | 35 ++-- .../code_plugins/plugin_model_inference.py | 14 +- napari_cellseg3d/config.py | 17 +- napari_cellseg3d/interface.py | 10 +- requirements.txt | 4 +- 8 files changed, 149 insertions(+), 140 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e15958e6..212c4120 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -7,6 +7,7 @@ from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 40a07893..9e6877da 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,18 +4,21 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label, regionprops +from skimage.measure import label +from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from tifffile import imread +from skimage.filters import thresholding +from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between +from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import Singleton # from napari_cellseg3d.utils import sphericity_volume_area @@ -80,6 +83,42 @@ def run_method(self, image): raise NotImplementedError("Must be defined in child classes") +class InstanceMethod: + def __init__( + self, + name: str, + function: callable, + num_sliders: int, + num_counters: int, + ): + self.name = name + self.function = function + self.counters: List[ui.DoubleIncrementCounter] = [] + self.sliders: List[ui.Slider] = [] + if num_sliders > 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider(0, 100, 1, divide_factor=100, text_label=""), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label=""), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + + @dataclass class ImageStats: volume: List[float] @@ -120,32 +159,27 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - # remove_small_size: float, + remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant - Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation + remove_small_size (float): remove all objects smaller than the specified size in pixels Returns: Instance segmentation labels from Voronoi-Otsu method - """ - # remove_small_size (float): remove all objects smaller than the specified size in pixels - # semantic = np.squeeze(volume) - logger.debug( - f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" - ) + semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( - volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma + semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return np.array(instance) + return instance def binary_connected( @@ -381,16 +415,13 @@ def fill(lst, n=len(properties) - 1): ) -class Watershed(InstanceMethod): - """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - - def __init__(self, widget_parent=None): +class Watershed(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=WATERSHED, + name="Watershed", function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -420,23 +451,20 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( image, - self.sliders[0].slider_value, - self.sliders[1].slider_value, + self.sliders[0].value(), + self.sliders[1].value(), self.counters[0].value(), self.counters[1].value(), ) -class ConnectedComponents(InstanceMethod): - """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - - def __init__(self, widget_parent=None): +class ConnectedComponents(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=CONNECTED_COMP, + name="Connected Components", function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, ) self.sliders[0].text_label.setText("Foreground probability threshold") @@ -454,56 +482,44 @@ def __init__(self, widget_parent=None): def run_method(self, image): return self.function( - image, self.sliders[0].slider_value, self.counters[0].value() + image, self.sliders[0].value(), self.counters[0].value() ) -class VoronoiOtsu(InstanceMethod): - """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - - def __init__(self, widget_parent=None): +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): + def __init__(self): super().__init__( - name=VORONOI_OTSU, + name="Voronoi-Otsu", function=voronoi_otsu, num_sliders=0, - num_counters=2, - widget_parent=widget_parent, + num_counters=3, ) - self.counters[0].label.setText("Spot sigma") # closeness + self.counters[0].label.setText("Spot sigma") self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[1].label.setText("Outline sigma") self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" self.counters[1].setMaximum(100) self.counters[1].setValue(2) - # self.counters[2].label.setText("Small object removal") - # self.counters[2].tooltips = ( - # "Volume/size threshold for small object removal." - # "\nAll objects with a volume/size below this value will be removed." - # ) - # self.counters[2].setValue(30) + self.counters[2].label.setText("Small object removal") + self.counters[2].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) def run_method(self, image): - ################ - # For debugging - # import napari - # view = napari.Viewer() - # view.add_image(image) - # napari.run() - ################ - return self.function( image, self.counters[0].value(), self.counters[1].value(), - # self.counters[2].value(), + self.counters[2].value(), ) @@ -518,70 +534,67 @@ def __init__(self, parent=None): Args: parent: parent widget - """ super().__init__(parent) + self.method_choice = ui.DropdownMenu( - list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) + INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = {} - """Contains the instance of the method, with its name as key""" + self.methods = [] self.instance_widgets = {} - """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() def _build(self): + group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - try: - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - method_class = method(widget_parent=self.parent()) - self.methods[name] = method_class - self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets ? - if len(method_class.sliders) > 0: - for slider in method_class.sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method_class.counters) > 0: - for counter in method_class.counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + self.instance_widgets[name] = [] + if len(method().sliders) > 0: + for slider in method().sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method().counters) > 0: + for counter in method().counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): - if name != self.method_choice.currentText(): - for widget in self.instance_widgets[name]: + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() + + for widget in self.instance_widgets[method.name]: + widget.set_visibility(True) + + for key in self.instance_widgets.keys(): + if key != method.name: + for widget in self.instance_widgets[key]: widget.set_visibility(False) - else: - for widget in self.instance_widgets[name]: - widget.set_visibility(True) def run_method(self, volume): """ Calls instance function with chosen parameters - Args: volume: image data to run method on Returns: processed image from self._method - """ - method = self.methods[self.method_choice.currentText()] + method = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ]() return method.run_method(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { - VORONOI_OTSU: VoronoiOtsu, - WATERSHED: Watershed, - CONNECTED_COMP: ConnectedComponents, + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, } diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 30d37bbd..14449854 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -49,12 +49,8 @@ from tqdm import tqdm # local -from napari_cellseg3d import config, utils -from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_instance_seg import ( - ImageStats, - volume_stats, -) +from napari_cellseg3d.code_models.model_instance_seg import ImageStats +from napari_cellseg3d.code_models.model_instance_seg import volume_stats logger = utils.LOGGER @@ -448,10 +444,11 @@ def model_output( ): inputs = inputs.to("cpu") - # def model_output(inputs): - # return post_process_transforms( - # self.config.model_info.get_model().get_output(model, inputs) - # ) + model_output = lambda inputs: post_process_transforms( + self.config.model_info.get_model().get_output( + model, inputs + ) # TODO(cyril) refactor those functions + ) def model_output(inputs): return post_process_transforms( @@ -600,8 +597,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c8370c1..6c0bc936 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -8,12 +8,10 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - InstanceWidgets, - clear_small_objects, - threshold, - to_semantic, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import threshold +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -345,19 +343,18 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 302a52c9..f9cac5f3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -13,6 +13,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -552,12 +557,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6df82043..a9e3b44f 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,9 +7,6 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -17,6 +14,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -114,17 +117,11 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - - @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None ################ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 276f9214..8f4f2cdd 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -6,6 +6,10 @@ import napari # Qt +from qtpy import QtCore +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt + # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore from qtpy.QtCore import QObject, Qt, QUrl @@ -1046,11 +1050,11 @@ def __init__( self.label = make_label(name=label) self.valueChanged.connect(self._update_step) - def _update_step(self): # FIXME check divide_factor + def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.01) - else: self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): diff --git a/requirements.txt b/requirements.txt index 3189e9c4..9c7126eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pre-commit -pyclesperanto-prototype>=0.22.0 -pysqlite3 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From 6a78fb026c946cdc3f220626f456b88f6ec30425 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 350/577] isort --- .../_tests/test_weight_download.py | 7 ++- .../code_models/model_instance_seg.py | 4 +- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 ++- napari_cellseg3d/config.py | 11 ++-- napari_cellseg3d/interface.py | 53 +++++++++---------- 6 files changed, 40 insertions(+), 45 deletions(-) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index d8886a56..bffe422b 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,7 +1,6 @@ -from napari_cellseg3d.code_models.model_workers import ( - WEIGHTS_DIR, - WeightsDownloader, -) +from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import WeightsDownloader + # DISABLED, causes GitHub actions to freeze diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 9e6877da..376cf56f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,11 +4,11 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget +from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.filters import thresholding from skimage.transform import resize # from skimage.measure import mesh_surface_area @@ -17,8 +17,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c0bc936..e4d7480b 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -9,9 +9,9 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index f9cac5f3..1da36989 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,15 +9,13 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index a9e3b44f..957946da 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,6 +7,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -14,12 +19,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 8f4f2cdd..d3cd4e84 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -7,35 +7,34 @@ # Qt from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt # from qtpy.QtCore import QtWarningMsg -from qtpy import QtCore -from qtpy.QtCore import QObject, Qt, QUrl -from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor -from qtpy.QtWidgets import ( - QCheckBox, - QComboBox, - QDoubleSpinBox, - QFileDialog, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QLineEdit, - QMenu, - QPushButton, - QRadioButton, - QScrollArea, - QSizePolicy, - QSlider, - QSpinBox, - QTextEdit, - QVBoxLayout, - QWidget, -) +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt +from qtpy.QtCore import QUrl +from qtpy.QtGui import QCursor +from qtpy.QtGui import QDesktopServices +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import QCheckBox +from qtpy.QtWidgets import QComboBox +from qtpy.QtWidgets import QDoubleSpinBox +from qtpy.QtWidgets import QFileDialog +from qtpy.QtWidgets import QGridLayout +from qtpy.QtWidgets import QGroupBox +from qtpy.QtWidgets import QHBoxLayout +from qtpy.QtWidgets import QLabel +from qtpy.QtWidgets import QLayout +from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QMenu +from qtpy.QtWidgets import QPushButton +from qtpy.QtWidgets import QRadioButton +from qtpy.QtWidgets import QScrollArea +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QSlider +from qtpy.QtWidgets import QSpinBox +from qtpy.QtWidgets import QTextEdit +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # Local from napari_cellseg3d import utils From 82cfccdf34032fd0f8b7ca53866a78e7555afbab Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 351/577] Fix inference --- .../code_models/model_instance_seg.py | 32 +- napari_cellseg3d/code_models/model_workers.py | 8 +- .../code_plugins/plugin_model_inference.py | 11 +- napari_cellseg3d/config.py | 6 +- notebooks/assess_instance.ipynb | 670 +----------------- 5 files changed, 46 insertions(+), 681 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 376cf56f..c0d246b1 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -540,8 +540,10 @@ def __init__(self, parent=None): self.method_choice = ui.DropdownMenu( INSTANCE_SEGMENTATION_METHOD_LIST.keys() ) - self.methods = [] + self.methods = {} + """Contains the instance of the method, with its name as key""" self.instance_widgets = {} + """Contains the lists of widgets for each methods, to show/hide""" self.method_choice.currentTextChanged.connect(self._set_visibility) self._build() @@ -551,17 +553,23 @@ def _build(self): group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) - for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): - self.instance_widgets[name] = [] - if len(method().sliders) > 0: - for slider in method().sliders: - group.layout.addWidget(slider.container) - self.instance_widgets[name].append(slider) - if len(method().counters) > 0: - for counter in method().counters: - group.layout.addWidget(counter.label) - group.layout.addWidget(counter) - self.instance_widgets[name].append(counter) + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets + if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 14449854..6003b0ae 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -542,9 +542,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -597,8 +595,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 1da36989..ff173b43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -555,9 +555,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -725,9 +726,7 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = self.worker_config.post_process_config.instance.method.name number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 957946da..84ba4215 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -115,12 +115,16 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() ################ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b2382c31..40412282 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,691 +4,47 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "tags": [] + "collapsed": true }, "outputs": [], "source": [ - "import napari\n", "import numpy as np\n", - "from pathlib import Path\n", "from tifffile import imread\n", - "\n", - "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", - "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import (\n", - " binary_connected,\n", - " binary_watershed,\n", - " voronoi_otsu,\n", - ")" + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [] - }, + "execution_count": null, "outputs": [], - "source": [ - "viewer = napari.Viewer()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "source": [], "metadata": { "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "pycharm": { + "name": "#%%\n" } - ], - "source": [ - "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", - "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", - "\n", - "prediction = imread(prediction_path)\n", - "gt_labels = imread(gt_labels_path)\n", - "\n", - "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", - "gt_labels_resized = resize(gt_labels, zoom)\n", - "\n", - "\n", - "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.5817600487210719" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from napari_cellseg3d.utils import dice_coeff\n", - "\n", - "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", - "\n", - "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n", - "125\n" - ] - } - ], - "source": [ - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)\n", - "print(np.unique(gt_labels_resized).shape[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized, thres_small=2)\n", - "viewer.add_labels(connected, name=\"connected\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(68,\n", - " 43,\n", - " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", - ")\n", - "viewer.add_labels(watershed)\n", - "eval.evaluate_model_performance(gt_labels_resized, watershed)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "\n", - "from skimage.morphology import remove_small_objects\n", - "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "dtype('int64')" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gt_labels_resized.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(voronoi, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(gt_labels_resized, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file From 16193a15604042e7a84ba8d9778549eaa4ceae64 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 352/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/model_instance_seg.py | 6 +- .../dev_scripts/artefact_labeling.py | 94 ++-- .../dev_scripts/correct_labels.py | 99 ++--- .../dev_scripts/evaluate_labels.py | 409 ++++-------------- notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 5 files changed, 551 insertions(+), 458 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index c0d246b1..cd101b35 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -165,14 +165,15 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Args: volume (np.ndarray): volume to segment spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - remove_small_size (float): remove all objects smaller than the specified size in pixels Returns: Instance segmentation labels from Voronoi-Otsu method + """ semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( @@ -534,6 +535,7 @@ def __init__(self, parent=None): Args: parent: parent widget + """ super().__init__(parent) @@ -590,10 +592,12 @@ def _set_visibility(self): def run_method(self, volume): """ Calls instance function with chosen parameters + Args: volume: image data to run method on Returns: processed image from self._method + """ method = INSTANCE_SEGMENTATION_METHOD_LIST[ self.method_choice.currentText() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 04e288d8..875ca9b6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,15 @@ +import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path +import scipy.ndimage as ndimage import os - import napari - # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - -# import sys -# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from skimage.filters import threshold_otsu """ New code by Yves Paychere @@ -43,9 +44,7 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append( - np.array([i, unique[np.argmax(counts)]]) - ) + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -62,7 +61,7 @@ def map_labels(labels, artefacts): def make_labels( - image, + path_image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -74,8 +73,8 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - image : str - image array + path_image : str + Path to image. path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -94,25 +93,21 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(image) + image = imread(path_image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( - np.max(image_contrasted) - np.min(image_contrasted) - ) + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size( - labels, min_size=threshold_size, is_labeled=True - ) + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -124,29 +119,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - image : np.array - image. - labels : np.array - labels. + path_image : str + Path to image. + path_labels : str + Path to labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - # image = imread(image) - # labels = imread(labels) - + image = imread(path_image) + labels = imread(path_labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smallest cube that contains all the non-zero pixels of a 3d image +# select the smalles cube that contains all the none zero pixel of an 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -164,15 +156,16 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(image, path_image_out): +def crop_image_path(path_image, path_image_out): """Crop image. Parameters ---------- - image : np.array - image + path_image : str + Path to image. path_image_out : str Path of the output image. """ + image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -220,9 +213,7 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile( - image[neurons], threshold_artefact_brightness_percent - ) + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -253,9 +244,7 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile( - sizes, threshold_artefact_size_percent - ) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -305,8 +294,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image, - labels, + image_path, + labels_path, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -315,10 +304,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image : np.array - image for artefact detection. - labels : np.array - label image array with each neurons labelled as a different int value. + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -328,6 +317,9 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ + image = imread(image_path) + labels = imread(labels_path) + artefacts = make_artefact_labels( image, labels, @@ -347,12 +339,11 @@ def visualize_images(paths): Parameters ---------- paths : list - List of images to visualize. + List of paths to images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - image = imread(path) - viewer.add_image(image) + viewer.add_image(imread(path), name=os.path.basename(path)) # wait for the user to close the viewer napari.run() @@ -379,12 +370,8 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [ - f for f in os.listdir(path + "/labels") if f.endswith(".tif") - ] - path_images = [ - f for f in os.listdir(path + "/volumes") if f.endswith(".tif") - ] + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] # sort the list path_labels.sort() path_images.sort() @@ -413,6 +400,7 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": + repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 77835007..f94327e2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,22 +1,19 @@ -import threading -import time -import warnings -from functools import partial -from pathlib import Path - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings from napari.qt.threading import thread_worker -from tifffile import imread, imwrite from tqdm import tqdm import threading - # import sys # sys.path.append(str(Path(__file__) / "../../")) +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels - """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -36,9 +33,7 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm( - range(len(unique_label)), desc="relabeling", ncols=100 - ): + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): i = unique_label[i_label] if i == 0: continue @@ -86,16 +81,13 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact, test=False): +def ask_labels(unique_artefact): global returns returns = [] - if not test: - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] - else: - i_labels_to_add_tmp = [0] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] if i_labels_to_add_tmp == [0]: print("no label added") @@ -138,9 +130,7 @@ def ask_labels(unique_artefact, test=False): print("close the napari window to continue") -def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 -): +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -154,8 +144,6 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 - viewer : napari.Viewer, optional - the napari viewer, by default None """ global returns @@ -170,9 +158,7 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -192,45 +178,30 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread( - target=partial(ask_labels, test=test), args=(unique_artefact,) - ) + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add), 0, artefact - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - if not test: - napari.run() + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add_tmp), artefact, 0 - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) print("these labels will be added") - if test: - viewer.close() - viewer = napari.view_image(image) if viewer is None else viewer - if not test: - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") - if test: - revert = "n" - viewer.close() + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) - if test: - break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -287,16 +258,12 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget( - old_label, new_label, map_labels_existing, delay=0.5 -): +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect( - lambda arg: modify_viewer(old_label, new_label, arg) - ) + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -313,12 +280,8 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] - ) - new_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] - ) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -327,9 +290,7 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget( - old_label, new_label, map_labels_existing, delay=delay - ) + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) napari.run() @@ -346,12 +307,12 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, - str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) ) if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3eb62764..857bcd19 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,24 +1,74 @@ -import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, -): +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): """Evaluate the model performance. Parameters ---------- @@ -26,10 +76,8 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - print_details : bool + do_print : bool If True, print the results. - visualize : bool - If True, visualize the results. Returns ------- neuron_found : float @@ -53,7 +101,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels, threshold_correct + labels, model_labels ) # calculate the number of neurons individually found @@ -71,9 +119,7 @@ def evaluate_model_performance( artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean( - [i[3] for i in map_labels_existing] - ) + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -82,9 +128,7 @@ def evaluate_model_performance( if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean( - [i[2] for i in map_fused_neurons] - ) + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -99,42 +143,27 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") - log.info( - "Mean true positive ratio of the model: ", - ) - log.info(mean_true_positive_ratio_model) - log.info( + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, ) - log.info(mean_ratio_pixel_found) - log.info( + print( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, ) - log.info(mean_ratio_pixel_found_fused) - log.info( + print( "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, ) - log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact ) - log.info(mean_ratio_false_pixel_artefact) - if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -150,21 +179,15 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, - unique_labels, - 0, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where( - np.isin(labels, neurones_not_found_labels), labels, 0 - ) + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), - model_labels, - 0, + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -192,81 +215,6 @@ def evaluate_model_performance( ) -def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > threshold_correct: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > threshold_correct: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels - - def save_as_csv(results, path): """ Save the results as a csv file @@ -278,7 +226,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - log.debug(np.array(results).shape) + print(np.array(results).shape) df = pd.DataFrame( [results], columns=[ @@ -296,193 +244,6 @@ def save_as_csv(results, path): df.to_csv(path, index=False) -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From ba2062eda0159464edf7689aaac02d1676069143 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 353/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../code_models/model_instance_seg.py | 2 +- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 ++- .../dev_scripts/evaluate_labels.py | 282 ++++++++++++++++-- notebooks/assess_instance.ipynb | 239 +++++++++++---- 5 files changed, 494 insertions(+), 107 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cd101b35..8b7e234b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -180,7 +180,7 @@ def voronoi_otsu( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) - return instance + return np.array(instance) def binary_connected( diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,14 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): + +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -21,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -28,15 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -44,14 +228,19 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,16 +248,22 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -78,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -119,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -128,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -144,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -179,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -226,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..86ef4e29 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -111,26 +106,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,112 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -178,7 +172,8 @@ ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected, name=\"connected\")\n", + "connected.dtype" ] }, { @@ -191,26 +186,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" ] }, { @@ -253,26 +247,25 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" ] }, { @@ -302,7 +295,9 @@ } ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] @@ -330,6 +325,10 @@ ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] @@ -337,6 +336,33 @@ { "cell_type": "code", "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -346,28 +372,94 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { @@ -375,12 +467,51 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + "\n" ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -389,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { From 890e3b1f4d3efd2514321fa5c4c2eb6ba4fb6db9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 354/577] Added pre-commit hooks --- .pre-commit-config.yaml | 44 +++++++++++++++++++++++++++++------------ requirements.txt | 4 +++- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..802dfe20 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,44 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v4.0.1 +# hooks: # - id: check-docstring-first - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - id: end-of-file-fixer +# - id: trailing-whitespace +# - repo: https://github.com/asottile/setup-cfg-fmt +# rev: v1.20.0 +# hooks: +# - id: setup-cfg-fmt +# - repo: https://github.com/PyCQA/flake8 +# rev: 4.0.1 +# hooks: +# - id: flake8 +# additional_dependencies: [flake8-typing-imports>=1.9.0] +# - repo: https://github.com/myint/autoflake +# rev: v1.4 +# hooks: +# - id: autoflake +# args: ["--in-place", "--remove-all-unused-imports"] +# - repo: https://github.com/PyCQA/isort +# rev: 5.10.1 +# hooks: +# - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.262' + rev: 'v0.0.257' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=79] + args: [--line-length=88] +# - repo: https://github.com/asottile/pyupgrade +# rev: v2.29.1 +# hooks: +# - id: pyupgrade +# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/requirements.txt b/requirements.txt index 9c7126eb..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pre-commit +pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 From a6ad19e2280148c1ca22c42084ca2f687838dc3c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:40:31 +0100 Subject: [PATCH 355/577] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 802dfe20..d1e22fb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: rev: 22.3.0 hooks: - id: black - args: [--line-length=88] + args: [--line-length=79] # - repo: https://github.com/asottile/pyupgrade # rev: v2.29.1 # hooks: From c9d76f437fde4e5cb92e95debd434c4465d1e016 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:48:32 +0100 Subject: [PATCH 356/577] Update pyproject.toml --- pyproject.toml | 52 ++------------------------------------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c94c4ee4..83aa1ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,57 +32,9 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" -[tool.setuptools] -include-package-data = true - -[tool.setuptools.packages.find] -where = ["."] - -[tool.setuptools.package-data] -"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] - [tool.ruff] -select = [ - "E", "F", "W", - "A", - "B", - "G", - "I", - "PT", - "PTH", - "RET", - "SIM", - "TCH", - "NPY", -] -# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -# and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004", "A003"] -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", - "docs/conf.py", - "napari_cellseg3d/_tests/conftest.py", -] +# Never enforce `E501` (line length violations). +ignore = ["E501"] [tool.black] line-length = 79 From 3a9548b82dfb12778aadca0b45e1efbf0b7e4630 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:50:33 +0100 Subject: [PATCH 357/577] Update pyproject.toml Ruff config --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 83aa1ebb..462263e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Never enforce `E501` (line length violations). -ignore = ["E501"] +ignore = ["E501", "E741"] [tool.black] line-length = 79 From 3f3829f2b4907dce00adb9f503e2076b44173dcb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 358/577] Enfore pre-commit style --- .gitignore | 5 +- .../code_models/model_instance_seg.py | 19 +- napari_cellseg3d/code_models/model_workers.py | 9 +- .../code_plugins/plugin_convert.py | 5 +- .../code_plugins/plugin_model_inference.py | 3 - napari_cellseg3d/config.py | 7 +- .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 471 ++++++++---------- notebooks/assess_instance.ipynb | 158 +++--- 10 files changed, 313 insertions(+), 366 deletions(-) diff --git a/.gitignore b/.gitignore index f8547d92..427603f1 100644 --- a/.gitignore +++ b/.gitignore @@ -104,7 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -*.prof - - -*.prof +notebooks/instance_test.ipynb diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 8b7e234b..45a20b3d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -4,12 +4,10 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -570,23 +568,16 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() - - for widget in self.instance_widgets[method.name]: - widget.set_visibility(True) - - for key in self.instance_widgets.keys(): - if key != method.name: - for widget in self.instance_widgets[key]: + for name in self.instance_widgets.keys(): + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: widget.set_visibility(False) def run_method(self, volume): diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 6003b0ae..5456c730 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -444,11 +444,10 @@ def model_output( ): inputs = inputs.to("cpu") - model_output = lambda inputs: post_process_transforms( - self.config.model_info.get_model().get_output( - model, inputs - ) # TODO(cyril) refactor those functions - ) + # def model_output(inputs): + # return post_process_transforms( + # self.config.model_info.get_model().get_output(model, inputs) + # ) def model_output(inputs): return post_process_transforms( diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index e4d7480b..3346d2b8 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,7 +4,8 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread +from tifffile import imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -143,7 +144,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + self.results_path.mkdir(exist_ok=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ff173b43..971f81bd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -9,9 +9,6 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 84ba4215..3ae070e2 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -7,10 +7,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -89,7 +86,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info("Model list :\n" + str(f"{name}\n" for name in MODEL_LIST)) + logger.info( + "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) + ) return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..3c5be52a 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -8,261 +8,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -344,15 +98,21 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") log.info( "Mean true positive ratio of the model: ", ) @@ -389,7 +149,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -460,6 +220,193 @@ def save_as_csv(results, path): df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons + # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 86ef4e29..609da8b3 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -50,12 +50,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -84,9 +86,36 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -110,28 +139,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "125\n" ] }, { @@ -162,7 +172,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -171,9 +181,8 @@ } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized, thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { @@ -190,21 +199,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -217,14 +229,14 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", " 1.0)" ] }, @@ -251,21 +263,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -278,15 +293,15 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 8, @@ -471,21 +486,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -498,15 +516,15 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 13, From fc67737469d667f3bc075e59f33278c4446b8bd1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 359/577] Update .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 427603f1..ee1bf4a0 100644 --- a/.gitignore +++ b/.gitignore @@ -104,4 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -notebooks/instance_test.ipynb + From ab78666da4b4b964738023f275e6ee013f29cb6c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 360/577] Version bump --- napari_cellseg3d/__init__.py | 2 +- .../code_plugins/plugin_helper.py | 2 +- setup.cfg | 32 +++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..2c537225 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc6" +__version__ = "0.0.2rc2" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index f8ac18ef..a20a2c61 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc6'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 6111ed7e..17ef734e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,35 @@ +[metadata] +name = napari-cellseg3d +version = 0.0.2rc2 +author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis +author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu + +license = MIT +description = plugin for cell segmentation +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Development Status :: 2 - Pre-Alpha + Intended Audience :: Science/Research + Framework :: napari + Topic :: Software Development :: Testing + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Operating System :: OS Independent + License :: OSI Approved :: MIT License + Topic :: Scientific/Engineering :: Artificial Intelligence + Topic :: Scientific/Engineering :: Image Processing + Topic :: Scientific/Engineering :: Visualization + +url = https://github.com/AdaptiveMotorControlLab/CellSeg3d +project_urls = + Bug Tracker = https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues + Documentation = https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html + Source Code = https://github.com/AdaptiveMotorControlLab/CellSeg3d + [options] packages = find: include_package_data = True From 830a25e500e5409de070c21d5f11ae509015eb87 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:33:40 +0200 Subject: [PATCH 361/577] Revert "Version bump" This reverts commit 6e39971b39fb926084f3ed71d82e8c25f68f8b6f. --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 2c537225..6e2681e8 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc2" +__version__ = "0.0.2rc1" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index a20a2c61..a3fd8c0d 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -37,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 17ef734e..5789f74f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc2 +version = 0.0.2rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 51f3f0f2c9e46a9bdb5b21bba4a39f5ccc2465c1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 362/577] Updated project files --- pyproject.toml | 22 +++++++++++++--------- setup.cfg | 8 +++++++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 462263e7..5dec250c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,15 +9,17 @@ authors = [ requires-python = ">=3.8" dependencies = [ "numpy", - "napari>=0.4.14", + "napari[all]>=0.4.14", "QtPy", "opencv-python>=4.5.5", + "dask-image>=0.6.0", "scikit-image>=0.19.2", "matplotlib>=3.4.1", "tifffile>=2022.2.9", "imageio-ffmpeg>=0.4.5", "torch>=1.11", "monai[nibabel,einops]>=0.9.0", + "itk", "tqdm", "nibabel", "scikit-image", @@ -32,6 +34,15 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] + +[tool.setuptools.package-data] +"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] + [tool.ruff] # Never enforce `E501` (line length violations). ignore = ["E501", "E741"] @@ -44,16 +55,10 @@ profile = "black" line_length = 79 [project.optional-dependencies] -all = [ - "napari[all]>=0.4.14", -] dev = [ "isort", "black", "ruff", - "tuna", - "pre-commit", - ] docs = [ "sphinx", @@ -67,5 +72,4 @@ test = [ "coverage", "tox", "twine", -] - +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 5789f74f..2420dd1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc1 +version = 0.0.2rc6 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu @@ -65,6 +65,12 @@ install_requires = [options.packages.find] where = . +[options.package_data] +napari-cellseg3d = + res/*.png + code_models/models/pretrained/*.json + napari.yaml + [options.entry_points] napari.manifest = napari-cellseg3d = napari_cellseg3d:napari.yaml From ea0ddaf0d255d5ba3204edb5f004d7b95c8d921c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 363/577] Fixed wrong value in instance sliders --- .../code_models/model_instance_seg.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 45a20b3d..5eb987f6 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -174,6 +174,9 @@ def voronoi_otsu( """ semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -450,8 +453,8 @@ def __init__(self): def run_method(self, image): return self.function( image, - self.sliders[0].value(), - self.sliders[1].value(), + self.sliders[0].slider_value, + self.sliders[1].slider_value, self.counters[0].value(), self.counters[1].value(), ) @@ -481,7 +484,7 @@ def __init__(self): def run_method(self, image): return self.function( - image, self.sliders[0].value(), self.counters[0].value() + image, self.sliders[0].slider_value, self.counters[0].value() ) @@ -538,7 +541,7 @@ def __init__(self, parent=None): super().__init__(parent) self.method_choice = ui.DropdownMenu( - INSTANCE_SEGMENTATION_METHOD_LIST.keys() + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) self.methods = {} """Contains the instance of the method, with its name as key""" @@ -558,7 +561,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -568,8 +571,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -590,9 +595,7 @@ def run_method(self, volume): Returns: processed image from self._method """ - method = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ]() + method = self.methods[self.method_choice.currentText()] return method.run_method(volume) From 845fbc909d787f66659ccf4aa60779a82fcd7d75 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 364/577] Removing dask-image --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ee1bf4a0..0ec12b01 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png +*.prof From bd77860e30e939c5abafbee8310add7318c6c95f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 365/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0abcf387..24f4e867 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,12 +1,9 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, - Utilities, -) +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS def test_utils_plugin(make_napari_viewer): @@ -24,9 +21,4 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) - if utils_name == "Convert to instance labels": - # to avoid issues with Voronoi-Otsu missing runtime - menu = widget.utils_widgets[i].instance_widgets.method_choice - menu.setCurrentIndex(menu.currentIndex() + 1) - widget.utils_widgets[i]._start() From 0ffa37773b7605b8fb1250ec39137805ac318874 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 366/577] Relabeling tests --- .gitignore | 6 +- .../_tests/test_labels_correction.py | 3 +- .../dev_scripts/artefact_labeling.py | 93 +++++++++---------- .../dev_scripts/correct_labels.py | 75 ++++++++++----- 4 files changed, 102 insertions(+), 75 deletions(-) diff --git a/.gitignore b/.gitignore index 0ec12b01..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -104,5 +104,9 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png - *.prof + +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..9d4e7801 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 9a344545..bf724a46 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,7 +1,5 @@ import numpy as np -from tifffile import imread -from tifffile import imwrite -from pathlib import Path +from tifffile import imwrite, imread import scipy.ndimage as ndimage import os import napari @@ -64,7 +62,7 @@ def map_labels(labels, artefacts): def make_labels( - path_image, + image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -76,7 +74,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - path_image : str + image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -96,7 +94,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(path_image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -126,28 +124,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - path_image : str - Path to image. - path_labels : str - Path to labels. + image : np.array + image. + labels : np.array + labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - image = imread(path_image) - labels = imread(path_labels) + # image = imread(image) + # labels = imread(labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smalles cube that contains all the none zero pixel of an 3d image +# select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -165,16 +161,15 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(path_image, path_image_out): +def crop_image_path(image, path_image_out): """Crop image. Parameters ---------- - path_image : str - Path to image. + image : np.array + image path_image_out : str Path of the output image. """ - image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -307,8 +302,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image_path, - labels_path, + image, + labels, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -317,10 +312,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image_path : str - Path to image file. - labels_path : str - Path to label image file with each neurons labelled as a different value. + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -330,9 +325,6 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ - image = imread(image_path) - labels = imread(labels_path) - artefacts = make_artefact_labels( image, labels, @@ -352,11 +344,12 @@ def visualize_images(paths): Parameters ---------- paths : list - List of paths to images to visualize. + List of images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - viewer.add_image(imread(path), name=os.path.basename(path)) + image = imread(path) + viewer.add_image(image) # wait for the user to close the viewer napari.run() @@ -416,22 +409,22 @@ def create_artefact_labels_from_folder( ) -if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] - print(f"REPO PATH : {repo_path}") - paths = [ - "dataset_clean/cropped_visual/train", - "dataset_clean/cropped_visual/val", - "dataset_clean/somatomotor", - "dataset_clean/visual_tif", - ] - for data_path in paths: - path = str(repo_path / data_path) - print(path) - create_artefact_labels_from_folder( - path, - do_visualize=False, - threshold_artefact_brightness_percent=20, - threshold_artefact_size_percent=1, - contrast_power=20, - ) +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index cd09754e..50f2e47a 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -4,6 +4,7 @@ import scipy.ndimage as ndimage import napari from pathlib import Path +from functools import partial import time import warnings from napari.qt.threading import thread_worker @@ -85,13 +86,16 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact): +def ask_labels(unique_artefact, test=False): global returns returns = [] - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] if i_labels_to_add_tmp == [0]: print("no label added") @@ -135,7 +139,13 @@ def ask_labels(unique_artefact): def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -150,6 +160,8 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None """ global returns @@ -164,9 +176,10 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -186,15 +199,22 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) t.start() artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - viewer = napari.view_image(image) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - napari.run() + if not test: + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones @@ -205,15 +225,26 @@ def relabel( np.isin(artefact, i_labels_to_add_tmp), artefact, 0 ) print("these labels will be added") - viewer = napari.view_image(image) - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") + if test: + viewer.close() + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) + if test: + break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -334,9 +365,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") - image_path = str(im_path / "image.tif") - gt_labels_path = str(im_path / "labels.tif") - - relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) From 023a5db1b5e00d878c2bf454a2a43821b829d2c0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 367/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 43 ++++++++++++----------------------------- pyproject.toml | 3 ++- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1e22fb1..da16a3b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,44 +1,25 @@ repos: -# - repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v4.0.1 -# hooks: -# - id: check-docstring-first -# - id: end-of-file-fixer -# - id: trailing-whitespace -# - repo: https://github.com/asottile/setup-cfg-fmt -# rev: v1.20.0 -# hooks: -# - id: setup-cfg-fmt -# - repo: https://github.com/PyCQA/flake8 -# rev: 4.0.1 -# hooks: -# - id: flake8 -# additional_dependencies: [flake8-typing-imports>=1.9.0] -# - repo: https://github.com/myint/autoflake -# rev: v1.4 -# hooks: -# - id: autoflake -# args: ["--in-place", "--remove-all-unused-imports"] -# - repo: https://github.com/PyCQA/isort -# rev: 5.10.1 -# hooks: -# - id: isort + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.257' + rev: 'v0.0.262' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black args: [--line-length=79] -# - repo: https://github.com/asottile/pyupgrade -# rev: v2.29.1 -# hooks: -# - id: pyupgrade -# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 5dec250c..d2a2adbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "isort", "black", "ruff", + "pre-commit", ] docs = [ "sphinx", @@ -72,4 +73,4 @@ test = [ "coverage", "tox", "twine", -] \ No newline at end of file +] From 682e4efdf74345658cb1fc619a8154427c3aa41b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 368/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da16a3b9..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,13 +2,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-docstring-first +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From 29f1469f6a56e4cff230e7465b17cf67a7dc3761 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 369/577] Run full suite of pre-commit hooks --- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/model_instance_seg.py | 10 ++++++++- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 3 +-- 6 files changed, 34 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 24f4e867..7908e8b4 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 5eb987f6..e33d1d0f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -15,7 +15,7 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -517,6 +517,14 @@ def __init__(self): ) def run_method(self, image): + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + return self.function( image, self.counters[0].value(), diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3c5be52a..26b45d3f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,10 +1,9 @@ +import napari import numpy as np from collections import Counter from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict -import napari from napari_cellseg3d.utils import LOGGER as log From 4be5f4224b41b3bdd994077dab9d34b067867e7a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 24 Mar 2023 17:08:44 +0100 Subject: [PATCH 370/577] Model class refactor --- docs/res/guides/custom_model_template.rst | 24 -- .../_tests/test_weight_download.py | 4 +- napari_cellseg3d/code_models/model_workers.py | 367 ++++++++++-------- .../code_models/models/model_SegResNet.py | 48 ++- .../code_models/models/model_SwinUNetR.py | 36 +- .../code_models/models/model_TRAILMAP.py | 39 +- .../code_models/models/model_TRAILMAP_MS.py | 27 +- .../code_models/models/model_VNet.py | 56 +-- .../code_models/models/model_test.py | 24 +- .../code_plugins/plugin_model_inference.py | 141 ++++--- .../code_plugins/plugin_model_training.py | 4 +- .../code_plugins/plugin_review.py | 2 +- napari_cellseg3d/config.py | 65 +++- napari_cellseg3d/interface.py | 18 +- napari_cellseg3d/utils.py | 18 +- notebooks/assess_instance.ipynb | 121 +++--- requirements.txt | 6 +- setup.cfg | 2 +- 18 files changed, 568 insertions(+), 434 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index afbcd98a..9bad49b0 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -10,28 +10,4 @@ To add a custom model, you will need a **.py** file with the following structure :: - def get_net(): - return ModelClass # should return the class of the model, - # for example SegResNet or UNET - - def get_weights_file(): - return "weights_file.pth" # name of the weights file for the model, - # which should be in *napari_cellseg3d/models/pretrained* - - - def get_output(model, input): - out = model(input) # should return the model's output as [C, N, D,H,W] - # (C: channel, N, batch size, D,H,W : depth, height, width) - return out - - - def get_validation(model, val_inputs): - val_outputs = model(val_inputs) # should return the proper type for validation - # with sliding_window_inference from MONAI - return val_outputs - - - def ModelClass(x1,x2...): - # your Pytorch model here... - return results # should return as [C, N, D,H,W] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index bffe422b..1bcb40d7 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR from napari_cellseg3d.code_models.model_workers import WeightsDownloader @@ -7,6 +7,6 @@ def test_weight_download(): downloader = WeightsDownloader() downloader.download_weights("test", "test.pth") - result_path = WEIGHTS_DIR / "test.pth" + result_path = PRETRAINED_WEIGHTS_DIR / "test.pth" assert result_path.is_file() diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 5456c730..f5bb798c 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List, Optional +import typing as t import numpy as np import torch @@ -41,7 +41,10 @@ from monai.utils import set_determinism # threads -from napari.qt.threading import GeneratorWorker, WorkerBaseSignals +from napari.qt.threading import GeneratorWorker + +# from napari.qt.threading import thread_worker +from napari.qt.threading import WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -64,14 +67,16 @@ # https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ # https://napari-staging-site.github.io/guides/stable/threading.html -WEIGHTS_DIR = Path(__file__).parent.resolve() / Path("models/pretrained") -logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {WEIGHTS_DIR}") +PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( + "models/pretrained" +) +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") class WeightsDownloader: """A utility class the downloads the weights of a model when needed.""" - def __init__(self, log_widget: Optional[ui.Log] = None): + def __init__(self, log_widget: t.Optional[ui.Log] = None): """ Creates a WeightsDownloader, optionally with a log widget to display the progress. @@ -93,11 +98,11 @@ def download_weights(self, model_name: str, model_weights_filename: str): import tarfile import urllib.request - def show_progress(count, block_size, total_size): + def show_progress(_, block_size, __): # count, block_size, total_size pbar.update(block_size) logger.info("*" * 20) - pretrained_folder_path = WEIGHTS_DIR + pretrained_folder_path = PRETRAINED_WEIGHTS_DIR json_path = pretrained_folder_path / Path("pretrained_model_urls.json") check_path = pretrained_folder_path / Path(model_weights_filename) @@ -167,12 +172,17 @@ def safe_extract( class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" # TODO link ? + Separate from Worker instances as indicated `here`_ + + .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + """ # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" warn_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" + error_signal = Signal(Exception, str) + """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -203,33 +213,24 @@ def __init__( ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. - Args: - * config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * device: cuda or cpu device to use for torch - - * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance + The config contains the following attributes: + * device: cuda or cpu device to use for torch + * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance + * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom + * results_path: the path to save the results to + * filetype: the file extension to use when saving, + * transforms: a dict containing transforms to perform at various times. + * instance: a dict containing parameters regarding instance segmentation + * use_window: use window inference with specific size or whole image + * window_infer_size: size of window if use_window is True + * keep_on_cpu: keep images on CPU or no + * stats_csv: compute stats on cells and save them to a csv file + * images_filepaths: the paths to the images of the dataset + * layer: the layer to run inference on - * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom - - * results_path: the path to save the results to - - * filetype: the file extension to use when saving, - - * transforms: a dict containing transforms to perform at various times. - - * instance: a dict containing parameters regarding instance segmentation - - * use_window: use window inference with specific size or whole image - - * window_infer_size: size of window if use_window is True - - * keep_on_cpu: keep images on CPU or no - - * stats_csv: compute stats on cells and save them to a csv file - - * images_filepaths: the paths to the images of the dataset + Args: + * worker_config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements - * layer: the layer to run inference on Note: See :py:func:`~self.inference` """ @@ -237,6 +238,7 @@ def __init__( self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self.config = worker_config @@ -269,6 +271,21 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Raises an error in main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + + self.log_signal.emit("!" * 20) + self.log_signal.emit("Error occured") + # self.log_signal.emit(msg) + # self.log_signal.emit(str(exception)) + + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + yield exception + # self.quit() + def log_parameters(self): config = self.config @@ -398,7 +415,7 @@ def load_layer(self): ) # for anisotropy to be monai-like, i.e. zyx # FIXME rotation not always correct dims_check = volume.shape - # self.log("\nChecking dimensions...") + self.log("Checking dimensions...") pad = utils.get_padding_dim(dims_check) # logger.debug(volume.shape) @@ -449,54 +466,61 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): - return post_process_transforms( - self.config.model_info.get_model().get_output(model, inputs) - ) - - dataset_device = ( - "cpu" if self.config.keep_on_cpu else self.config.device - ) - - window_size = self.config.sliding_window_config.window_size - window_overlap = self.config.sliding_window_config.window_overlap - - # FIXME - # import sys - - # old_stdout = sys.stdout - # old_stderr = sys.stderr - - # sys.stdout = self.downloader.log_widget - # sys.stdout = self.downloader.log_widget - - outputs = sliding_window_inference( - inputs, - roi_size=[window_size, window_size, window_size], - sw_batch_size=1, # TODO add param - predictor=model_output, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) - - # sys.stdout = old_stdout - # sys.stderr = old_stderr - - out = outputs.detach().cpu() - - if aniso_transform is not None: - out = aniso_transform(out) + if self.config.keep_on_cpu: + dataset_device = "cpu" + else: + dataset_device = self.config.device - if post_process: - out = np.array(out).astype(np.float32) - out = np.squeeze(out) - return out + if self.config.sliding_window_config.is_enabled(): + window_size = self.config.sliding_window_config.window_size + window_size = [window_size, window_size, window_size] + window_overlap = self.config.sliding_window_config.window_overlap else: - return out + window_size = None + window_overlap = 0 + try: + # logger.debug(f"model : {model}") + logger.debug(f"inputs shape : {inputs.shape}") + logger.debug(f"inputs type : {inputs.dtype}") + try: + # outputs = model(inputs) + + def model_output_wrapper(inputs): + result = model(inputs) + return post_process_transforms(result) + + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) + except Exception as e: + logger.error(e, exc_info=True) + logger.debug("failed to run sliding window inference") + self.raise_error(e, "Error during sliding window inference") + logger.debug(f"Inference output shape: {outputs.shape}") + self.log("Post-processing...") + out = outputs.detach().cpu().numpy() + if aniso_transform is not None: + out = aniso_transform(out) + if post_process: + out = np.array(out).astype(np.float32) + out = np.squeeze(out) + return out + else: + return out + except Exception as e: + logger.error(e, exc_info=True) + self.raise_error(e, "Error during sliding window inference") + # sys.stdout = old_stdout + # sys.stderr = old_stderr - def create_result_dict( # FIXME replace with result class + def create_inference_result( self, semantic_labels, instance_labels, @@ -570,7 +594,10 @@ def save_image( + f"_{time}_" + self.config.filetype ) - imwrite(file_path, image) + try: + imwrite(file_path, image) + except ValueError as e: + self.raise_error(e, "Error during image saving") filename = Path(file_path).stem if from_layer: @@ -635,7 +662,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.log(f"Inference completed on image n°{i+1}") - return self.create_result_dict( + return self.create_inference_result( out, instance_labels, from_layer=False, @@ -646,9 +673,7 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): def stats_csv(self, instance_labels): if self.config.compute_stats: - stats = volume_stats( - instance_labels - ) # TODO test with area mesh function + stats = volume_stats(instance_labels) return stats # except ValueError as e: @@ -674,13 +699,14 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels, stats = self.get_instance_result(out, from_layer=True) - return self.create_result_dict( + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels, from_layer=True, stats=stats, ) + # @thread_worker(connect={"errored": self.raise_error}) def inference(self): """ Requires: @@ -723,35 +749,68 @@ def inference(self): try: dims = self.config.model_info.model_input_size - # self.log(f"MODEL DIMS : {dims}") + self.log(f"MODEL DIMS : {dims}") model_name = self.config.model_info.name model_class = self.config.model_info.get_model() - self.log(model_name) + self.log(f"Model name : {model_name}") weights_config = self.config.weights_config post_process_config = self.config.post_process_config - if model_name == "SegResNet": - model = model_class.get_net( - input_image_size=[ - dims, - dims, - dims, - ], # TODO FIX ! find a better way & remove model-specific code + # try: + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[128, 128, 128], + ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, ) - elif model_name == "SwinUNetR": - model = model_class.get_net( - img_size=[dims, dims, dims], - use_checkpoint=False, + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - else: - model = model_class.get_net() - model = model.to(self.config.device) + model.load_state_dict( + torch.load( + weights, + map_location=self.config.device, + ) + ) + self.log("Done") + # except Exception as e: + # self.raise_error(e, "Issue loading weights") + # except Exception as e: + # self.raise_error(e, "Issue instantiating model") + + # if model_name == "SegResNet": + # model = model_class( + # input_image_size=[ + # dims, + # dims, + # dims, + # ], + # ) + # elif model_name == "SwinUNetR": + # model = model_class( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = model_class.get_net() self.log_parameters() - model.to(self.config.device) - # load_transforms = Compose( # [ # LoadImaged(keys=["image"]), @@ -772,25 +831,6 @@ def inference(self): AsDiscrete(threshold=t), EnsureType() ) - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path - else: - self.downloader.download_weights( - model_name, - model_class.get_weights_file(), - ) - weights = str( - WEIGHTS_DIR / Path(model_class.get_weights_file()) - ) - model.load_state_dict( - torch.load( - weights, - map_location=self.config.device, - ) - ) - self.log("Done") - is_folder = self.config.images_filepaths is not None is_layer = self.config.layer is not None @@ -815,6 +855,9 @@ def inference(self): else: raise ValueError("No data has been provided. Aborting.") + if model is None: + raise ValueError("Model is None") + model.eval() with torch.no_grad(): ################################ @@ -830,9 +873,10 @@ def inference(self): input_image, model, post_process_transforms ) model.to("cpu") - + # self.quit() except Exception as e: - self.log(f"Error during inference : {e}") + logger.error(e, exc_info=True) + self.raise_error(e, "Inference failed") self.quit() finally: self.quit() @@ -842,10 +886,10 @@ def inference(self): class TrainingReport: show_plot: bool = True epoch: int = 0 - loss_values: List = None - validation_metric: List = None + loss_values: t.Dict = None # TODO(cyril) : change to dict and unpack different losses for e.g. WNet with several losses + validation_metric: t.List = None weights: np.array = None - images: List[np.array] = None + images: t.List[np.array] = None class TrainingWorker(GeneratorWorker): @@ -897,6 +941,7 @@ def __init__( self._signals = LogSignal() self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal self._weight_error = False ############################################# @@ -922,6 +967,14 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Sends an error to main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + self.quit() + def log_parameters(self): self.log("-" * 20) self.log("Parameters summary :\n") @@ -1051,23 +1104,14 @@ def train(self): do_sampling = self.config.sampling - if model_name == "SegResNet": - size = self.config.sample_size if do_sampling else check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - input_image_size=utils.get_padding_dim(size), - # out_channels=1, - # dropout_prob=0.3, - ) - elif model_name == "SwinUNetR": - size = self.sample_size if do_sampling else check - logger.info(f"Size of image : {size}") - model = model_class.get_net( - img_size=utils.get_padding_dim(size), - use_checkpoint=True, - ) + if do_sampling: + size = self.config.sample_size else: - model = model_class.get_net() # get an instance of the model + size = check + + model = model_class( # FIXME check if correct + input_img_size=utils.get_padding_dim(size), use_checkpoint=True + ) model = model.to(self.config.device) epoch_loss_values = [] @@ -1207,7 +1251,11 @@ def train(self): else: load_whole_images = Compose( [ - LoadImaged(keys=["image", "label"]), + LoadImaged( + keys=["image", "label"], + # image_only=True, + # reader=WSIReader(backend="tifffile") + ), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="PLI"), SpatialPadd( @@ -1254,9 +1302,9 @@ def train(self): if weights_config.custom: if weights_config.use_pretrained: - weights_file = model_class.get_weights_file() + weights_file = model_class.weights_file self.downloader.download_weights(model_name, weights_file) - weights = WEIGHTS_DIR / Path(weights_file) + weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) weights_config.path = weights else: weights = str(Path(weights_config.path)) @@ -1270,6 +1318,7 @@ def train(self): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") + logger.error(e, exc_info=True) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" @@ -1317,7 +1366,7 @@ def train(self): batch_data["label"].to(device), ) optimizer.zero_grad() - outputs = model_class.get_output(model, inputs) + outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") loss = self.config.loss_function(outputs, labels) loss.backward() @@ -1348,10 +1397,24 @@ def train(self): val_data["image"].to(device), val_data["label"].to(device), ) - - val_outputs = model_class.get_validation( - model, val_inputs + self.log("Performing validation...") + try: + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) + except Exception as e: + self.raise_error(e, "Error during validation") + logger.debug( + f"val_outputs shape : {val_outputs.shape}" ) + # val_outputs = model(val_inputs) pred = decollate_batch(val_outputs) @@ -1398,7 +1461,7 @@ def train(self): weights=model.state_dict(), images=checkpoint_output, ) - + self.log("Validation completed") yield train_report weights_filename = ( @@ -1431,7 +1494,7 @@ def train(self): model.to("cpu") except Exception as e: - self.log(f"Error in training : {e}") + self.raise_error(e, "Error in training") self.quit() finally: self.quit() diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py index 8856e18d..8b6e6e65 100644 --- a/napari_cellseg3d/code_models/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,21 +1,33 @@ from monai.networks.nets import SegResNetVAE -def get_net(input_image_size, out_channels=1, dropout_prob=0.3): - return SegResNetVAE( - input_image_size, out_channels=out_channels, dropout_prob=dropout_prob - ) - - -def get_weights_file(): - return "SegResNet.pth" - - -def get_output(model, input): - out = model(input)[0] - return out - - -def get_validation(model, val_inputs): - val_outputs = model(val_inputs) - return val_outputs[0] +class SegResNet_(SegResNetVAE): + use_default_training = True + weights_file = "SegResNet.pth" + + def __init__( + self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs + ): + super().__init__( + input_img_size, + out_channels=out_channels, + dropout_prob=dropout_prob, + ) + + def forward(self, x): + res = SegResNetVAE.forward(self, x) + # logger.debug(f"SegResNetVAE.forward: {res[0].shape}") + return res[0] + + def get_model_test(self, size): + return SegResNetVAE( + size, in_channels=1, out_channels=1, dropout_prob=0.3 + ) + + # def get_output(model, input): + # out = model(input)[0] + # return out + + # def get_validation(model, val_inputs): + # val_outputs = model(val_inputs) + # return val_outputs[0] diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 532aeb89..fe4d380c 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,25 +1,23 @@ -import torch from monai.networks.nets import SwinUNETR -def get_weights_file(): - return "Swin64_best_metric.pth" +class SwinUNETR_(SwinUNETR): + use_default_training = True + weights_file = "Swin64_best_metric.pth" + def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs + ) -def get_net(img_size, use_checkpoint=True): - return SwinUNETR( - img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - ) + # def get_output(self, input): + # out = self(input) + # return torch.sigmoid(out) - -def get_output(model, input): - out = model(input) - return torch.sigmoid(out) - - -def get_validation(model, val_inputs): - return model(val_inputs) + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 09de2a26..8a108e37 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,28 +2,8 @@ from torch import nn -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" - # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them - - -def get_net(): - return TRAILMAP(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - return model(val_inputs) - - class TRAILMAP(nn.Module): - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input self.conv1 = self.encoderBlock(32, 64, 3) # l1 @@ -112,3 +92,20 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) return out + + +class TRAILMAP_(TRAILMAP): + use_default_training = True + weights_file = "TRAILMAP_PyTorch.pth" # model additionally trained on Mathis/Wyss mesoSPIM data + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + # def get_output(model, input): + # out = model(input) + # + # return out + + # def get_validation(model, val_inputs): + # return model(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 0fc68d34..e3ca00a6 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,20 +1,21 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D -def get_weights_file(): - # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) - return "TRAILMAP_MS_best_metric_epoch_26.pth" - - -def get_net(): - return UNet3D(1, 1) +class TRAILMAP_MS_(UNet3D): + use_default_training = True + weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) -def get_output(model, input): - out = model(input) - - return out + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + # def get_output(self, input): + # out = self(input) -def get_validation(model, val_inputs): - return model(val_inputs) + # return out + # + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 0c854832..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -1,29 +1,33 @@ -from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -def get_net(): - return VNet() - - -def get_weights_file(): - return "VNet_40e.pth" - - -def get_output(model, input): - out = model(input) - return out - - -def get_validation(model, val_inputs): - roi_size = (64, 64, 64) - sw_batch_size = 1 - val_outputs = sliding_window_inference( - val_inputs, - roi_size, - sw_batch_size, - model, - mode="gaussian", - overlap=0.7, - ) - return val_outputs +class VNet_(VNet): + use_default_training = True + weights_file = "VNet_40e.pth" + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError: + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) + + # def get_output(self, input): + # out = self(input) + # return out + + # def get_validation(self, val_inputs): # FIXME standardize + # roi_size = (64, 64, 64) + # sw_batch_size = 1 + # val_outputs = sliding_window_inference( + # val_inputs, + # roi_size, + # sw_batch_size, + # self, + # # mode="gaussian", + # # overlap=0.7, + # ) + # return val_outputs diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 5871c4a7..1ccac3da 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -2,26 +2,22 @@ from torch import nn -def get_weights_file(): - return "test.pth" - - class TestModel(nn.Module): - def __init__(self): + use_default_training = True + weights_file = "test.pth" + + def __init__(self, **kwargs): super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - def get_net(self): - return self - - def get_output(self, _, input): - return input + # def get_output(self, _, input): + # return input - def get_validation(self, val_inputs): - return val_inputs + # def get_validation(self, val_inputs): + # return val_inputs # if __name__ == "__main__": @@ -29,8 +25,8 @@ def get_validation(self, val_inputs): # model = TestModel() # model.train() # model.zero_grad() -# from napari_cellseg3d.config import WEIGHTS_DIR +# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR # torch.save( # model.state_dict(), -# WEIGHTS_DIR + f"/{get_weights_file()}" +# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" # ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 971f81bd..ab61b590 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -159,6 +159,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) + self.window_size_choice.setCurrentIndex(3) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -601,10 +602,13 @@ def start(self): self.worker.set_download_log(self.log) self.worker.started.connect(self.on_start) + self.worker.log_signal.connect(self.log.print_and_log) self.worker.warn_signal.connect(self.log.warn) + self.worker.error_signal.connect(self.log.error) + self.worker.yielded.connect(partial(self.on_yield)) # - self.worker.errored.connect(partial(self.on_yield)) + self.worker.errored.connect(partial(self.on_error)) self.worker.finished.connect(self.on_finish) if self.get_device(show=False) == "cuda": @@ -641,15 +645,18 @@ def on_start(self): self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") - def on_error(self): - """Catches errors and tries to clean up. TODO : upgrade""" + def on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" * 20) self.log.print_and_log("Worker errored...") - self.log.print_and_log("Trying to clean up...") + self.log.error(error) + # self.log.print_and_log("Trying to clean up...") + self.worker.quit() self.btn_start.setText("Start") self.btn_close.setVisible(True) - self.worker = None self.worker_config = None + self.worker = None self.empty_cuda_cache() def on_finish(self): @@ -672,83 +679,91 @@ def on_yield(self, result: InferenceResult): data (dict): dict yielded by :py:func:`~inference()`, contains : "image_id" : index of the returned image, "original" : original volume used for inference, "result" : inference result widget (QWidget): widget for accessing attributes """ + + if isinstance(result, Exception): + self.on_error(result) + # raise result # viewer, progress, show_res, show_res_number, zoon, show_original # check that viewer checkbox is on and that max number of displays has not been reached. # widget.log.print_and_log(result) + try: + image_id = result.image_id + model_name = result.model_name + if self.worker_config.images_filepaths is not None: + total = len(self.worker_config.images_filepaths) + else: + total = 1 - image_id = result.image_id - model_name = result.model_name - if self.worker_config.images_filepaths is not None: - total = len(self.worker_config.images_filepaths) - else: - total = 1 + viewer = self._viewer - viewer = self._viewer + pbar_value = image_id // total + if pbar_value == 0: + pbar_value = 1 - pbar_value = image_id // total - if pbar_value == 0: - pbar_value = 1 + self.progress.setValue(100 * pbar_value) - self.progress.setValue(100 * pbar_value) + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): + zoom = self.worker_config.post_process_config.zoom.zoom_values - if ( - self.config.show_results - and image_id <= self.config.show_results_count - ): - zoom = self.worker_config.post_process_config.zoom.zoom_values + viewer.dims.ndisplay = 3 + viewer.scale_bar.visible = True - viewer.dims.ndisplay = 3 - viewer.scale_bar.visible = True + if self.config.show_original and result.original is not None: + viewer.add_image( + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, + ) + + out_colormap = "twilight" + if self.worker_config.post_process_config.thresholding.enabled: + out_colormap = "turbo" - if self.config.show_original and result.original is not None: viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, ) - out_colormap = "twilight" - if self.worker_config.post_process_config.thresholding.enabled: - out_colormap = "turbo" - - viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, - ) - - if result.instance_labels is not None: - labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + if result.instance_labels is not None: + labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + stats = result.stats - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if self.worker_config.compute_stats and stats is not None: + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + self.log.print_and_log( + f"Number of instances : {stats.number_objects}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) + except Exception as e: + self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index cf8e4b85..4f2f7cdf 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -985,7 +985,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1153,7 +1153,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e) + logger.error(e, exc_info=True) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 7ed6c549..e3e05f6c 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -401,7 +401,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3ae070e2..3d1d6d9e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -10,12 +10,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet -from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.code_models.models import ( - model_TRAILMAP_MS as TRAILMAP_MS, -) -from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ +from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ +from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ +from napari_cellseg3d.code_models.models.model_VNet import VNet_ + from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -24,16 +23,15 @@ # TODO(cyril) add JSON load/save MODEL_LIST = { - "SegResNet": SegResNet, - "VNet": VNet, + "SegResNet": SegResNet_, + "VNet": VNet_, # "TRAILMAP": TRAILMAP, - "TRAILMAP_MS": TRAILMAP_MS, - "SwinUNetR": SwinUNetR, + "TRAILMAP_MS": TRAILMAP_MS_, + "SwinUNetR": SwinUNETR_, # "test" : DO NOT USE, reserved for testing } - -WEIGHTS_DIR = str( +PRETRAINED_WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") ) @@ -69,8 +67,11 @@ class ReviewSession: @dataclass class ModelInfo: - """Dataclass recording model info : - - name (str): name of the model""" + """Dataclass recording model info + Args: + name (str): name of the model + model_input_size (Optional[List[int]]): input size of the model + """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None @@ -94,7 +95,7 @@ def get_model_name_list(): @dataclass class WeightsInfo: - path: str = WEIGHTS_DIR + path: str = PRETRAINED_WEIGHTS_DIR custom: bool = False use_pretrained: Optional[bool] = False @@ -121,6 +122,14 @@ class InstanceSegConfig: @dataclass class PostProcessConfig: + """Class to record params for post processing + + Args: + zoom (Zoom): zoom config + thresholding (Thresholding): thresholding config + instance (InstanceSegConfig): instance segmentation config + """ + zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() instance: InstanceSegConfig = InstanceSegConfig() @@ -141,7 +150,15 @@ def is_enabled(self): @dataclass class InfererConfig: - """Class to record params for Inferer plugin""" + """Class to record params for Inferer plugin + + Args: + model_info (ModelInfo): model info + show_results (bool): show results in napari + show_results_count (int): number of results to show + show_original (bool): show original image in napari + anisotropy_resolution (List[int]): anisotropy resolution + """ model_info: ModelInfo = None show_results: bool = False @@ -152,7 +169,21 @@ class InfererConfig: @dataclass class InferenceWorkerConfig: - """Class to record configuration for Inference job""" + """Class to record configuration for Inference job + + Args: + device (str): device to use for inference + model_info (ModelInfo): model info + weights_config (WeightsInfo): weights info + results_path (str): path to save results + filetype (str): filetype to save results + keep_on_cpu (bool): keep results on cpu + compute_stats (bool): compute stats + post_process_config (PostProcessConfig): post processing config + sliding_window_config (SlidingWindowConfig): sliding window config + images_filepaths (str): path to images to infer + layer (napari.layers.Layer): napari layer to infer on + """ device: str = "cpu" model_info: ModelInfo = ModelInfo() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d3cd4e84..57b3b0bd 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -295,6 +295,22 @@ def warn(self, warning): finally: self.lock.release() + def error(self, error, msg=None): + """Show exception and message from another thread""" + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Excepetion caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() + ############## # UI elements @@ -1199,7 +1215,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") filenames = QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + widget, "Open directory", default_path + "/.." ) return filenames diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index ecb6a199..5683c541 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,10 +2,8 @@ import warnings from datetime import datetime from pathlib import Path - import numpy as np - -# from dask import delayed +from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -38,6 +36,18 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] +# class TiffFileReader(ImageReader): +# def __init__(self): +# super().__init__() +# +# def verify_suffix(self, filename): +# if filename == "tif": +# return True +# def read(self, data, **kwargs): +# return tfl_imread(data) +# +# def get_data(self, data): +# return data, {} def normalize_x(image): """Normalizes the values of an image array to be between [-1;1] rather than [0;255] @@ -122,8 +132,6 @@ def dice_coeff(y_true, y_pred): def resize(image, zoom_factors): - from monai.transforms import Zoom - isotropic_image = Zoom( zoom_factors, keep_size=False, diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 609da8b3..169775f5 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -49,10 +49,20 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -62,14 +72,15 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", + "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = resize(prediction, zoom)\n", + "prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -92,7 +103,7 @@ { "data": { "text/plain": [ - "0.5817600487210719" + "0.7538125057831502" ] }, "execution_count": 4, @@ -103,9 +114,15 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", + "semantic_gt = to_semantic(gt_labels_resized.copy())\n", + "semantic_pred = to_semantic(prediction_resized.copy())\n", + "\n", + "viewer.add_image(semantic_gt, colormap='bop blue')\n", + "viewer.add_image(semantic_pred, colormap='red')\n", + "\n", "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", + " semantic_gt,\n", + " prediction_resized\n", ")" ] }, @@ -172,7 +189,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -199,24 +216,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-24 14:23:13,590 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", + "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", + "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" ] }, { @@ -229,15 +246,15 @@ { "data": { "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" + "(63,\n", + " 45,\n", + " 16,\n", + " 16,\n", + " 0.819027731148306,\n", + " 0.8401649108992161,\n", + " 0.83609908334452,\n", + " 0.8066092803671974,\n", + " 0.98)" ] }, "execution_count": 7, @@ -263,24 +280,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" + "2023-03-24 14:23:13,732 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", + "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", + "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -293,15 +310,15 @@ { "data": { "text/plain": [ - "(68,\n", - " 43,\n", + "(77,\n", + " 34,\n", " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" + " 9,\n", + " 0.728461197681457,\n", + " 0.8885669859686413,\n", + " 0.8950588507577087,\n", + " 0.7472814623489069,\n", + " 0.878614359974009)" ] }, "execution_count": 8, @@ -339,7 +356,7 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", @@ -486,24 +503,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-24 14:23:14,241 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", + "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", + "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" ] }, { @@ -516,15 +533,15 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(102,\n", + " 8,\n", + " 14,\n", + " 16,\n", + " 0.708505702558253,\n", + " 0.8832633585884945,\n", + " 0.9759871495093808,\n", + " 0.6670483272595948,\n", + " 0.8653680990771155)" ] }, "execution_count": 13, diff --git a/requirements.txt b/requirements.txt index 3189e9c4..3ca0e56d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ black coverage +imageio-ffmpeg>=0.4.5 isort itk pytest @@ -15,13 +16,12 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 +ruff tifffile>=2022.2.9 -imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,einops]>=1.0.1 +monai[nibabel,einops,tifffile]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 2420dd1c..f3294b60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=1.0.1 + monai[nibabel,einops,tifffile]>=1.0.1 itk tqdm nibabel From 5f5978e7f3af662f2d3c2317a1573bbfaa29d723 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 29 Mar 2023 09:55:58 +0200 Subject: [PATCH 371/577] Added LR scheduler in training - Added ReduceLROnPlateau with params in training - Updated training guide - Minor UI attribute refactor - black --- docs/res/code/plugin_model_training.rst | 1 - docs/res/guides/training_module_guide.rst | 2 + napari_cellseg3d/_tests/fixtures.py | 3 + .../_tests/test_plugin_inference.py | 2 +- .../code_models/model_framework.py | 2 +- .../code_models/model_instance_seg.py | 8 +-- napari_cellseg3d/code_models/model_workers.py | 11 ++++ napari_cellseg3d/code_plugins/plugin_base.py | 2 +- .../code_plugins/plugin_convert.py | 4 +- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- .../code_plugins/plugin_model_inference.py | 4 +- .../code_plugins/plugin_model_training.py | 62 ++++++++++++------- .../code_plugins/plugin_utilities.py | 2 +- napari_cellseg3d/config.py | 2 + napari_cellseg3d/interface.py | 43 +++++++------ 15 files changed, 93 insertions(+), 57 deletions(-) diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index 870dfd14..dc1271fc 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -18,6 +18,5 @@ Methods Attributes ********************* - .. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index fb8992d2..05ce69be 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -74,6 +74,8 @@ The training module is comprised of several tabs. * The **batch size** (larger means quicker training and possibly better performance but increased memory usage) * The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.) * The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.) +* The **schedular patience**, which is the amount of epoch at a plateau that is waited for until the learning rate is reduced +* The **scheduler factor**, which is the factor by which to reduce the learning rate once a plateau is reached * Whether to use deterministic training, and the seed to use. .. note:: diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index b40a77d3..bd6b0ac7 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -14,3 +14,6 @@ def print_and_log(self, text, printing=None): def warn(self, warning): warnings.warn(warning) + + def error(self, e): + raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..66c50fba 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -38,4 +38,4 @@ def test_inference(make_napari_viewer, qtbot): # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: # blocker.connect(widget.worker.errored) - # assert len(viewer.layers) == 2 + #### assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 2cc4265e..d541b486 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -78,7 +78,7 @@ def __init__( # ) self.model_choice = ui.DropdownMenu( - sorted(self.available_models.keys()), label="Model name" + sorted(self.available_models.keys()), text_label="Model name" ) self.weights_filewidget = ui.FilePathWidget( diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index e33d1d0f..0c87a2df 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -73,7 +73,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label="", parent=None), + ui.DoubleIncrementCounter(text_label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -426,13 +426,13 @@ def __init__(self): num_counters=2, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" self.sliders[0].setValue(50) - self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].label.setText("Seed probability threshold") self.sliders[1].tooltips = "Probability threshold for seeding" self.sliders[1].setValue(90) @@ -469,7 +469,7 @@ def __init__(self): num_counters=1, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index f5bb798c..bca24035 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -70,6 +70,7 @@ PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) +VERBOSE_SCHEDULER = True logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") @@ -1292,6 +1293,13 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + mode="min", + factor=self.config.scheduler_factor, + patience=self.config.scheduler_patience, + verbose=VERBOSE_SCHEDULER, + ) dice_metric = DiceMetric(include_background=True, reduction="mean") best_metric = -1 @@ -1384,6 +1392,9 @@ def train(self): epoch_loss_values.append(epoch_loss) self.log(f"Epoch: {epoch + 1}, Average loss: {epoch_loss:.4f}") + self.log("Updating scheduler...") + scheduler.step(epoch_loss) + checkpoint_output = [] if ( diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 0a613ee7..2cb3581b 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -99,7 +99,7 @@ def __init__( ) self.filetype_choice = ui.DropdownMenu( - [".tif", ".tiff"], label="File format" + [".tif", ".tiff"], text_label="File format" ) ######## qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 3346d2b8..a847ebf7 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -210,7 +210,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): lower=1, upper=100000, default=10, - label="Remove all smaller than (pxs):", + text_label="Remove all smaller than (pxs):", ) self.results_path = Path.home() / Path("cellseg3d/small_removed") @@ -472,7 +472,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): upper=100000.0, step=0.5, default=10.0, - label="Remove all smaller than (value):", + text_label="Remove all smaller than (value):", ) self.results_path = Path.home() / Path("cellseg3d/threshold") diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 9830d51e..1647e858 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -533,7 +533,7 @@ def set_slice( # container_widget.extend(sliders) ui.add_widgets( container_widget.layout, - [ui.combine_blocks(s, s.text_label) for s in sliders], + [ui.combine_blocks(s, s.label) for s in sliders], ) # vw.window.add_dock_widget([spinbox, container_widget], area="right") wdgts = vw.window.add_dock_widget( diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ab61b590..44e70a76 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -105,7 +105,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ###################### # TODO : better way to handle SegResNet size reqs ? self.model_input_size = ui.IntIncrementCounter( - lower=1, upper=1024, default=128, label="\nModel input size" + lower=1, upper=1024, default=128, text_label="\nModel input size" ) self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size @@ -157,7 +157,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): # ) self.window_size_choice = ui.DropdownMenu( - sizes_window, label="Window size" + sizes_window, text_label="Window size" ) self.window_size_choice.setCurrentIndex(3) # set to 64 by default diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 4f2f7cdf..132c9531 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -46,6 +46,8 @@ class Trainer(ModelFramework, metaclass=ui.QWidgetSingleton): Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" + default_config = config.TrainingWorkerConfig() + def __init__( self, viewer: "napari.viewer.Viewer", @@ -168,14 +170,13 @@ def __init__( ################################ # interface - default = config.TrainingWorkerConfig() self.zip_choice = ui.CheckBox("Compress results") self.validation_percent_choice = ui.Slider( lower=10, upper=90, - default=default.validation_percent * 100, + default=self.default_config.validation_percent * 100, step=5, parent=self, ) @@ -183,12 +184,12 @@ def __init__( self.epoch_choice = ui.IntIncrementCounter( lower=2, upper=200, - default=default.max_epochs, - label="Number of epochs : ", + default=self.default_config.max_epochs, + text_label="Number of epochs : ", ) self.loss_choice = ui.DropdownMenu( - sorted(self.loss_dict.keys()), label="Loss function" + sorted(self.loss_dict.keys()), text_label="Loss function" ) self.lbl_loss_choice = self.loss_choice.label self.loss_choice.setCurrentIndex(0) @@ -196,7 +197,7 @@ def __init__( self.sample_choice_slider = ui.Slider( lower=2, upper=50, - default=default.num_samples, + default=self.default_config.num_samples, text_label="Number of patches per image : ", ) @@ -205,13 +206,13 @@ def __init__( self.batch_choice = ui.Slider( lower=1, upper=10, - default=default.batch_size, + default=self.default_config.batch_size, text_label="Batch size : ", ) self.val_interval_choice = ui.IntIncrementCounter( - default=default.validation_interval, - label="Validation interval : ", + default=self.default_config.validation_interval, + text_label="Validation interval : ", ) self.epoch_choice.valueChanged.connect(self._update_validation_choice) @@ -228,12 +229,24 @@ def __init__( ] self.learning_rate_choice = ui.DropdownMenu( - learning_rate_vals, label="Learning rate" + learning_rate_vals, text_label="Learning rate" ) self.lbl_learning_rate_choice = self.learning_rate_choice.label self.learning_rate_choice.setCurrentIndex(1) + self.scheduler_patience_choice = ui.IntIncrementCounter( + 1, + 99, + default=self.default_config.scheduler_patience, + text_label="Scheduler patience", + ) + self.scheduler_factor_choice = ui.Slider( + divide_factor=100, + default=self.default_config.scheduler_factor * 100, + text_label="Scheduler factor :", + ) + self.augment_choice = ui.CheckBox("Augment data") self.close_buttons = [ @@ -268,7 +281,8 @@ def __init__( "Deterministic training", func=self._toggle_deterministic_param ) self.box_seed = ui.IntIncrementCounter( - upper=10000000, default=default.deterministic_config.seed + upper=10000000, + default=self.default_config.deterministic_config.seed, ) self.lbl_seed = ui.make_label("Seed", self) self.container_seed = ui.combine_blocks( @@ -309,6 +323,12 @@ def set_tooltips(): self.learning_rate_choice.setToolTip( "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights" ) + self.scheduler_factor_choice.setToolTip( + "The factor by which to reduce the learning rate once the loss reaches a plateau" + ) + self.scheduler_patience_choice.setToolTip( + "The amount of epochs to wait for before reducing the learning rate" + ) self.augment_choice.setToolTip( "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images" " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image" @@ -632,26 +652,20 @@ def _build(self): "Training parameters", r=1, b=5, t=11 ) - spacing = 20 - ui.add_widgets( train_param_group_l, [ self.batch_choice.container, # batch size - ui.combine_blocks( - self.learning_rate_choice, - self.lbl_learning_rate_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # learning rate + self.lbl_learning_rate_choice, + self.learning_rate_choice, self.epoch_choice.label, # epochs self.epoch_choice, self.val_interval_choice.label, self.val_interval_choice, # validation interval + self.scheduler_patience_choice.label, + self.scheduler_patience_choice, + self.scheduler_factor_choice.label, + self.scheduler_factor_choice.container, ], None, ) @@ -833,6 +847,8 @@ def start(self): max_epochs=self.epoch_choice.value(), loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.value(), validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index fdcad6d3..45c0c119 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -43,7 +43,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # self.small = RemoveSmallUtils(self._viewer) self.utils_choice = ui.DropdownMenu( - UTILITIES_WIDGETS.keys(), label="Utilities" + UTILITIES_WIDGETS.keys(), text_label="Utilities" ) self._build() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 3d1d6d9e..afc16bd3 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -230,6 +230,8 @@ class TrainingWorkerConfig: max_epochs: int = 5 loss_function: callable = None learning_rate: np.float64 = 1e-3 + scheduler_patience: int = 10 + scheduler_factor: float = 0.5 validation_interval: int = 2 batch_size: int = 1 results_path_folder: str = str(Path.home() / Path("cellseg3d/training")) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 57b3b0bd..9a100dc2 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -408,21 +408,21 @@ def __init__( self, entries: Optional[list] = None, parent: Optional[QWidget] = None, - label: Optional[str] = None, + text_label: Optional[str] = None, fixed: Optional[bool] = True, ): """Args: entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None - label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + text_label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. """ super().__init__(parent) self.label = None if entries is not None: self.addItems(entries) - if label is not None: - self.label = QLabel(label) + if text_label is not None: + self.label = QLabel(text_label) if fixed: self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) @@ -473,9 +473,10 @@ def __init__( self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.text_label = None + self.label = None self.container = ContainerWidget( - # parent=self.parent + # parent=self.parent, + b=0, ) self._divide_factor = divide_factor @@ -498,7 +499,7 @@ def __init__( ) if text_label is not None: - self.text_label = make_label(text_label, parent=self) + self.label = make_label(text_label, parent=self) if default < lower: self._warn_outside_bounds(default) @@ -517,14 +518,14 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.text_label.setVisible(visible) + self.label.setVisible(visible) def _build_container(self): - if self.text_label is not None: + if self.label is not None: add_widgets( self.container.layout, [ - self.text_label, + self.label, combine_blocks(self._value_label, self, b=0), ], ) @@ -568,8 +569,8 @@ def tooltips(self, tooltip: str): self.setToolTip(tooltip) self._value_label.setToolTip(tooltip) - if self.text_label is not None: - self.text_label.setToolTip(tooltip) + if self.label is not None: + self.label.setToolTip(tooltip) @property def slider_value(self): @@ -739,7 +740,9 @@ def __init__( self.image = None self.layer_type = layer_type - self.layer_list = DropdownMenu(parent=self, label=name, fixed=False) + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? self._viewer.layers.events.inserted.connect(partial(self._add_layer)) @@ -1044,7 +1047,7 @@ def __init__( step: Optional[float] = 1.0, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[float]): minimum value, defaults to 0 @@ -1053,7 +1056,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) @@ -1061,8 +1064,8 @@ def __init__( self.layout = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) self.valueChanged.connect(self._update_step) def _update_step(self): @@ -1122,7 +1125,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[int]): minimum value, defaults to 0 @@ -1138,8 +1141,8 @@ def __init__( self.label = None self.container = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) @property def tooltips(self): From 65347c46c8e19d051e3e2caf53101db8b5f87148 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 31 Mar 2023 15:45:00 +0200 Subject: [PATCH 372/577] Update assess_instance.ipynb --- notebooks/assess_instance.ipynb | 162 ++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 61 deletions(-) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 169775f5..3dae22a9 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -49,20 +49,10 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -72,15 +62,16 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "# prediction_resized = resize(prediction, zoom)\n", - "prediction_resized = prediction # for trailmap\n", + "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -103,7 +94,7 @@ { "data": { "text/plain": [ - "0.7538125057831502" + "0.8592223181276479" ] }, "execution_count": 4, @@ -189,7 +180,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -216,24 +207,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,590 - Mapping labels...\n" + "2023-03-31 15:37:19,775 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", - "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", - "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" + "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", + "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -246,15 +237,15 @@ { "data": { "text/plain": [ - "(63,\n", - " 45,\n", - " 16,\n", - " 16,\n", - " 0.819027731148306,\n", - " 0.8401649108992161,\n", - " 0.83609908334452,\n", - " 0.8066092803671974,\n", - " 0.98)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 7, @@ -280,24 +271,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,732 - Mapping labels...\n" + "2023-03-31 15:37:19,919 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", - "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", - "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", + "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -310,15 +301,15 @@ { "data": { "text/plain": [ - "(77,\n", - " 34,\n", + "(68,\n", + " 43,\n", " 13,\n", - " 9,\n", - " 0.728461197681457,\n", - " 0.8885669859686413,\n", - " 0.8950588507577087,\n", - " 0.7472814623489069,\n", - " 0.878614359974009)" + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 8, @@ -344,6 +335,40 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:37:21,076 - build program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", + "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", + "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", + "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", + "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", + "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", + "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", + "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", + "2023-03-31 15:37:32,685 - build program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", + "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" + ] + }, { "data": { "text/plain": [ @@ -503,24 +528,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,241 - Mapping labels...\n" + "2023-03-31 15:37:36,854 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", - "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", - "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" + "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", + "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", + "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -533,15 +558,15 @@ { "data": { "text/plain": [ - "(102,\n", + "(109,\n", + " 2,\n", + " 13,\n", " 8,\n", - " 14,\n", - " 16,\n", - " 0.708505702558253,\n", - " 0.8832633585884945,\n", - " 0.9759871495093808,\n", - " 0.6670483272595948,\n", - " 0.8653680990771155)" + " 0.8285521200005869,\n", + " 0.8809251900364068,\n", + " 0.9838709677419355,\n", + " 0.782258064516129,\n", + " 1.0)" ] }, "execution_count": 13, @@ -565,10 +590,25 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" + ] + } + ], "source": [ "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -587,7 +627,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, From ee8af0ae4a270430d7369aeceab4b9e43e57cfbf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 11:09:30 +0200 Subject: [PATCH 373/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..df67a187 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb *.prof #include test data From 03cacc8c039cfbf820446ab79453a76cdfae359a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 14:27:21 +0200 Subject: [PATCH 374/577] Started adding WNet --- napari_cellseg3d/code_models/model_workers.py | 4 +- .../code_models/models/model_SwinUNetR.py | 29 +- .../code_models/models/model_TRAILMAP_MS.py | 15 +- .../code_models/models/model_WNet.py | 27 ++ .../pretrained/pretrained_model_urls.json | 1 + .../code_models/models/wnet/__init__.py | 0 .../code_models/models/wnet/crf.py | 112 ++++++ .../code_models/models/wnet/model.py | 189 ++++++++++ .../code_models/models/wnet/soft_Ncuts.py | 352 ++++++++++++++++++ napari_cellseg3d/config.py | 22 ++ 10 files changed, 739 insertions(+), 12 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/model_WNet.py create mode 100644 napari_cellseg3d/code_models/models/wnet/__init__.py create mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py create mode 100644 napari_cellseg3d/code_models/models/wnet/model.py create mode 100644 napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index bca24035..7a45c47e 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -761,7 +761,9 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=[128, 128, 128], + input_img_size=dims, + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) # try: model = model.to(self.config.device) diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index fe4d380c..f38409b8 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,7 @@ from monai.networks.nets import SwinUNETR +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class SwinUNETR_(SwinUNETR): @@ -6,14 +9,24 @@ class SwinUNETR_(SwinUNETR): weights_file = "Swin64_best_metric.pth" def __init__(self, input_img_size, use_checkpoint=True, **kwargs): - super().__init__( - input_img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - **kwargs - ) + try: + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs, + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index e3ca00a6..1123173a 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,4 +1,7 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER class TRAILMAP_MS_(UNet3D): @@ -8,9 +11,15 @@ class TRAILMAP_MS_(UNet3D): # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) def __init__(self, in_channels=1, out_channels=1, **kwargs): - super().__init__( - in_channels=in_channels, out_channels=out_channels, **kwargs - ) + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py new file mode 100644 index 00000000..63a91b10 --- /dev/null +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -0,0 +1,27 @@ +from napari_cellseg3d.code_models.models.wnet.model import WNet + + +class WNet_(WNet): + use_default_training = False + weights_file = "wnet.pth" + + def __init__( + self, + in_channels=1, + out_channels=1, + num_classes=2, + device="cpu", + **kwargs + ): + super().__init__( + device=device, + in_channels=in_channels, + out_channels=out_channels, + num_classes=num_classes, + ) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + # dec = self.forward_decoder(enc) + return enc diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index cd0782fb..cde5e332 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -3,5 +3,6 @@ "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz", "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", + "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" } diff --git a/napari_cellseg3d/code_models/models/wnet/__init__.py b/napari_cellseg3d/code_models/models/wnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py new file mode 100644 index 00000000..ca11fba2 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -0,0 +1,112 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +import numpy as np +import pydensecrf.densecrf as dcrf +from pydensecrf.utils import ( + unary_from_softmax, + create_pairwise_gaussian, + create_pairwise_bilateral, +) + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py new file mode 100644 index 00000000..585ea0dd --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -0,0 +1,189 @@ +""" +Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. +The model performs unsupervised segmentation of 3D images. +""" + +import torch +import torch.nn as nn + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", +] + + +class WNet(nn.Module): + """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. + The model performs unsupervised segmentation of 3D images. + It first encodes the input image into a latent space using the U-Net UEncoder, then decodes it back to the original image using the U-Net UDecoder. + """ + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super(WNet, self).__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + self.decoder = UNet(device, num_classes, out_channels, encoder=False) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + dec = self.forward_decoder(enc) + return enc, dec + + def forward_encoder(self, x): + """Forward pass of the encoder part of the W-Net model.""" + enc = self.encoder(x) + return enc + + def forward_decoder(self, enc): + """Forward pass of the decoder part of the W-Net model.""" + dec = self.decoder(enc) + return dec + + +class UNet(nn.Module): + """Half of the W-Net model, based on the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels, encoder=True): + super(UNet, self).__init__() + self.device = device + self.in_b = InBlock(device, in_channels, 64) + self.conv1 = Block(device, 64, 128) + self.conv2 = Block(device, 128, 256) + self.conv3 = Block(device, 256, 512) + self.bot = Block(device, 512, 1024) + self.deconv1 = Block(device, 1024, 512) + self.deconv2 = Block(device, 512, 256) + self.deconv3 = Block(device, 256, 128) + self.out_b = OutBlock(device, 128, out_channels) + + self.sm = nn.Softmax(dim=1).to(device) + self.encoder = encoder + + def forward(self, x): + """Forward pass of the U-Net model.""" + in_b = self.in_b(x.to(self.device)) + c1 = self.conv1(nn.MaxPool3d(2)(in_b)) + c2 = self.conv2(nn.MaxPool3d(2)(c1)) + c3 = self.conv3(nn.MaxPool3d(2)(c2)) + x = self.bot(nn.MaxPool3d(2)(c3)) + x = self.deconv1( + torch.cat( + [ + c3, + nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv2( + torch.cat( + [ + c2, + nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.deconv3( + torch.cat( + [ + c1, + nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + x = self.out_b( + torch.cat( + [ + in_b, + nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + )(x), + ], + dim=1, + ) + ) + if self.encoder: + x = self.sm(x) + return x + + +class InBlock(nn.Module): + """Input block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(InBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the input block.""" + return self.module(x.to(self.device)) + + +class Block(nn.Module): + """Basic block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(Block, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), + nn.Conv3d(in_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.Conv3d(out_channels, out_channels, 1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(out_channels, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the basic block.""" + return self.module(x.to(self.device)) + + +class OutBlock(nn.Module): + """Output block of the U-Net architecture.""" + + def __init__(self, device, in_channels, out_channels): + super(OutBlock, self).__init__() + self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, 64, 3, padding=1, device=device), + nn.ReLU(), + nn.Dropout(p=0.65), + nn.BatchNorm3d(64, device=device), + nn.Conv3d(64, out_channels, 1, device=device), + ).to(device) + + def forward(self, x): + """Forward pass of the output block.""" + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py new file mode 100644 index 00000000..6a625355 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -0,0 +1,352 @@ +""" +Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. +The implementation was adapted and approximated to reduce computational and memory cost. +This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from scipy.stats import norm + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", + "Jianbo Shi", + "Jitendra Malik", + "Frank Odom", +] + + +class SoftNCutsLoss(nn.Module): + """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. + + Args: + data_shape (H, W, D): shape of the images as a tuple. + o_i (scalar): scale of the gaussian kernel of pixels brightness. + o_x (scalar): scale of the gaussian kernel of pixels spacial distance. + radius (scalar): radius of pixels for which we compute the weights + """ + + def __init__(self, data_shape, device, o_i, o_x, radius=None): + super(SoftNCutsLoss, self).__init__() + self.o_i = o_i + self.o_x = o_x + self.radius = radius + self.H = data_shape[0] + self.W = data_shape[1] + self.D = data_shape[2] + self.device = device + + if self.radius is None: + self.radius = min( + max(5, math.ceil(min(self.H, self.W, self.D) / 20)), + self.H, + self.W, + self.D, + ) + + # self.distances, self.indexes = self.get_distances() + + """ + + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration + distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) + distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) + distances_D = torch.tensor(range(self.D)).expand(self.D, self.D) # (D, D) + + # Compute in cuda if possible + if torch.cuda.is_available(): + distances_H = distances_H.cuda() + distances_W = distances_W.cuda() + distances_D = distances_D.cuda() + + distances_H = torch.abs(torch.subtract(distances_H, distances_H.T)) # (H, H) + distances_W = torch.abs(torch.subtract(distances_W, distances_W.T)) # (W, W) + distances_D = torch.abs(torch.subtract(distances_D, distances_D.T)) # (D, D) + + distances_H = distances_H.view(self.H, 1, 1, self.H, 1, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (H, 1, 1, H, 1, 1) -> (H, W, D, H, W, D) + distances_W = distances_W.view(1, self.W, 1, 1, self.W, 1).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, W, 1, 1, W, 1) -> (H, W, D, H, W, D) + distances_D = distances_D.view(1, 1, self.D, 1, 1, self.D).expand( + self.H, self.W, self.D, self.H, self.W, self.D + ).to_sparse() # (1, 1, D, 1, 1, D) -> (H, W, D, H, W, D) + + mask_H = torch.le(distances_H, self.radius).bool() # (H, W, D, H, W, D) + mask_W = torch.le(distances_W, self.radius).bool() # (H, W, D, H, W, D) + mask_D = torch.le(distances_D, self.radius).bool() # (H, W, D, H, W, D) + + distances_H = (distances_H * mask_H) # (H, W, D, H, W, D) + distances_W = (distances_W * mask_W) # (H, W, D, H, W, D) + distances_D = (distances_D * mask_D) # (H, W, D, H, W, D) + + mask_H =mask_H.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_W =mask_W.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + mask_D =mask_D.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) + + distances_H = distances_H.pow(2) # (H, W, D, H, W, D) + distances_W = distances_W.pow(2) # (H, W, D, H, W, D) + distances_D = distances_D.pow(2) # (H, W, D, H, W, D) + + squared_distances = torch.add( + torch.add(distances_H, distances_W), + distances_D, + ) # (H, W, D, H, W, D) + + squared_distances = squared_distances.flatten(0, 2).flatten( + 1, 3 + ) # (H*W*D, H*W*D) + + # Mask to only keep the weights for the pixels in the radius + self.mask = torch.le(squared_distances, self.radius**2).bool() # (H*W*D, H*W*D) + + # Add all masks to get the final mask + self.mask = self.mask.logical_and(mask_H).logical_and(mask_W).logical_and(mask_D) # (H*W*D, H*W*D) + + W_X = torch.exp( + torch.neg(torch.div(squared_distances, self.o_x)) + ) # (H*W*D, H*W*D) + + self.W_X = torch.mul(W_X, self.mask) # (H*W*D, H*W*D) + """ + + def forward(self, labels, inputs): + """Forward pass of the Soft N-Cuts loss. + + Args: + labels (torch.Tensor): Tensor of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + The Soft N-Cuts loss of shape (N,). + """ + inputs.shape[0] + inputs.shape[1] + K = labels.shape[1] + + labels.to(self.device) + inputs.to(self.device) + + loss = 0 + + kernel = self.gaussian_kernel(self.radius, self.o_x).to(self.device) + + for k in range(K): + # Compute the average pixel value for this class, and the difference from each pixel + class_probs = labels[:, k].unsqueeze(1) + class_mean = torch.mean( + inputs * class_probs, dim=(2, 3, 4), keepdim=True + ) / torch.add( + torch.mean(class_probs, dim=(2, 3, 4), keepdim=True), 1e-5 + ) + diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1) + + # Weight the loss by the difference from the class average. + weights = torch.exp(diff.pow(2).mul(-1 / self.o_i**2)) + + numerator = torch.sum( + class_probs + * F.conv3d(class_probs * weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + denominator = torch.sum( + class_probs * F.conv3d(weights, kernel, padding=self.radius), + dim=(1, 2, 3, 4), + ) + loss += nn.L1Loss()( + numerator / torch.add(denominator, 1e-6), + torch.zeros_like(numerator), + ) + + return K - loss + + """ + for k in range(K): + Ak = labels[:, k, :, :, :] # (N, H, W, D) + flatted_Ak = Ak.view(N, -1) # (N, H*W*D) + + # Compute the numerator of the Soft N-Cuts loss for k + flatted_Ak_unsqueeze = flatted_Ak.unsqueeze(1) # (N, 1, H*W*D) + transposed_Ak = torch.transpose(flatted_Ak_unsqueeze, 1, 2) # (N, H*W*D, 1) + probs = torch.bmm(transposed_Ak, flatted_Ak_unsqueeze) # (N, H*W*D, H*W*D) + probs_unsqueeze_expanded = probs.unsqueeze(1) # (N, 1, H*W*D, H*W*D) + numerator_elements = torch.mul( + probs_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + numerator = torch.sum(numerator_elements, dim=(2, 3)) # (N, C) + + # Compute the denominator of the Soft N-Cuts loss for k + expanded_flatted_Ak = flatted_Ak.expand( + -1, self.H * self.W * self.D + ) # (N, H*W*D, H*W*D) + e_f_Ak_unsqueeze_expanded = expanded_flatted_Ak.unsqueeze( + 1 + ) # (N, 1, H*W*D, H*W*D) + denominator_elements = torch.mul( + e_f_Ak_unsqueeze_expanded, weights + ) # (N, C, H*W*D, H*W*D) + denominator = torch.sum(denominator_elements, dim=(2, 3)) # (N, C) + + # Compute the Soft N-Cuts loss for k + division = torch.div(numerator, torch.add(denominator, 1e-8)) # (N, C) + loss = torch.sum(division, dim=1) # (N,) + losses.append(loss) + + loss = torch.sum(torch.stack(losses, dim=0), dim=0) # (N,) + + return torch.add(torch.neg(loss), K) + """ + + def gaussian_kernel(self, radius, sigma): + """Computes the Gaussian kernel. + + Args: + radius (int): The radius of the kernel. + sigma (float): The standard deviation of the Gaussian distribution. + + Returns: + The Gaussian kernel of shape (1, 1, 2*radius+1, 2*radius+1, 2*radius+1). + """ + x_2 = np.linspace(-radius, radius, 2 * radius + 1) ** 2 + dist = ( + np.sqrt( + x_2.reshape(-1, 1, 1) + + x_2.reshape(1, -1, 1) + + x_2.reshape(1, 1, -1) + ) + / sigma + ) + kernel = norm.pdf(dist) / norm.pdf(0) + kernel = torch.from_numpy(kernel.astype(np.float32)) + kernel = kernel.view( + (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]) + ) + + return kernel + + def get_distances(self): + """Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration. + + Returns: + distances (dict): for each pixel index, we get the distances to the pixels in a radius around it. + """ + + distances = dict() + indexes = np.array( + [ + (i, j, k) + for i in range(self.H) + for j in range(self.W) + for k in range(self.D) + ] + ) + + for i in indexes: + iTuple = (i[0], i[1], i[2]) + distances[iTuple] = dict() + + sliceD = indexes[ + i[0] * self.H + + i[1] * self.W + + max(0, i[2] - self.radius) : i[0] * self.H + + i[1] * self.W + + min(self.D, i[2] + self.radius) + ] + sliceW = indexes[ + i[0] * self.H + + max(0, i[1] - self.radius) * self.W + + i[2] : i[0] * self.H + + min(self.W, i[1] + self.radius) * self.W + + i[2] : self.D + ] + sliceH = indexes[ + max(0, i[0] - self.radius) * self.H + + i[1] * self.W + + i[2] : min(self.H, i[0] + self.radius) * self.H + + i[1] * self.W + + i[2] : self.D * self.W + ] + + for j in np.concatenate((sliceD, sliceW, sliceH)): + jTuple = (j[0], j[1], j[2]) + distance = np.linalg.norm(i - j) + if distance > self.radius: + continue + distance = math.exp(-(distance**2) / (self.o_x**2)) + + if jTuple not in distances: + distances[iTuple][jTuple] = distance + + return distances, indexes + + def get_weights(self, inputs): + """Computes the weights matrix for the Soft N-Cuts loss. + + Args: + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + list: List of the weights dict for each image in the batch. + """ + + """ + weights = [] + for n in range(inputs.shape[0]): + weightsChannel = [] + for c in range(inputs.shape[1]): + weightsImage = dict() + for i in self.indexes: + iTuple = (i[0], i[1], i[2]) + weightsImage[iTuple] = dict() + for j in self.indexes: + jTuple = (j[0], j[1], j[2]) + if iTuple in self.distances and jTuple in self.distances[i]: + brightness = ( + inputs[n][c][i[0]][i[1]][i[2]] + - inputs[n][c][j[0]][j[1]][j[2]] + ) ** 2 + brightness = math.exp(-brightness / self.o_i**2) + weightsImage[iTuple][jTuple] = ( + self.distances[iTuple][jTuple] * brightness + ) + + weightsChannel.append(weightsImage) + + weights.append(weightsChannel) + + return weights + + """ + + # Compute the brightness distance of the pixels + flatted_inputs = inputs.view( + inputs.shape[0], inputs.shape[1], -1 + ) # (N, C, H*W*D) + I_diff = torch.subtract( + flatted_inputs.unsqueeze(3), flatted_inputs.unsqueeze(2) + ) # (N, C, H*W*D, H*W*D) + masked_I_diff = torch.mul(I_diff, self.mask) # (N, C, H*W*D, H*W*D) + squared_I_diff = torch.pow(masked_I_diff, 2) # (N, C, H*W*D, H*W*D) + + W_I = torch.exp( + torch.neg(torch.div(squared_I_diff, self.o_i)) + ) # (N, C, H*W*D, H*W*D) + W_I = torch.mul(W_I, self.mask) # (N, C, H*W*D, H*W*D) + + # Get the spatial distance of the pixels + unsqueezed_W_X = self.W_X.view( + 1, 1, self.W_X.shape[0], self.W_X.shape[1] + ) # (1, 1, H*W*D, H*W*D) + + W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) + return W diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index afc16bd3..4eaddb93 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -14,6 +14,7 @@ from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ +from napari_cellseg3d.code_models.models.model_WNet import WNet_ from napari_cellseg3d.utils import LOGGER @@ -28,6 +29,7 @@ # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, + "WNet": WNet_, # "test" : DO NOT USE, reserved for testing } @@ -71,10 +73,12 @@ class ModelInfo: Args: name (str): name of the model model_input_size (Optional[List[int]]): input size of the model + num_classes (int): number of classes for the model """ name: str = next(iter(MODEL_LIST)) model_input_size: Optional[List[int]] = None + num_classes: int = 2 def get_model(self): try: @@ -240,3 +244,21 @@ class TrainingWorkerConfig: sample_size: List[int] = None do_augmentation: bool = True deterministic_config: DeterministicConfig = DeterministicConfig() + + +################ +# CRF config for WNet +################ + + +@dataclass +class WNetCRFConfig: + "Class to store parameters of WNet CRF post processing" + + # CRF + sa = 10 # 50 + sb = 10 + sg = 1 + w1 = 10 # 50 + w2 = 10 + n_iter = 5 From 84f4227fb1dbbec2038ab800d15636b8514677c8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 20 Apr 2023 11:12:59 +0200 Subject: [PATCH 375/577] Specify no grad in inference --- napari_cellseg3d/code_models/model_workers.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 7a45c47e..6bc088e6 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -490,16 +490,17 @@ def model_output_wrapper(inputs): result = model(inputs) return post_process_transforms(result) - outputs = sliding_window_inference( - inputs, - roi_size=window_size, - sw_batch_size=1, # TODO add param - predictor=model_output_wrapper, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) + with torch.no_grad(): + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) except Exception as e: logger.error(e, exc_info=True) logger.debug("failed to run sliding window inference") @@ -1412,16 +1413,17 @@ def train(self): ) self.log("Performing validation...") try: - val_outputs = sliding_window_inference( - val_inputs, - roi_size=size, - sw_batch_size=self.config.batch_size, - predictor=model, - overlap=0.25, - sw_device=self.config.device, - device=self.config.device, - progress=True, - ) + with torch.no_grad(): + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) except Exception as e: self.raise_error(e, "Error during validation") logger.debug( From b370b8c31ec6be39fb98246056c87b1d7c35a63a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 22 Apr 2023 14:12:32 +0200 Subject: [PATCH 376/577] First functional WNet inference, no CRF --- napari_cellseg3d/code_models/model_workers.py | 46 +++++++++++---- .../code_models/models/model_WNet.py | 3 +- .../code_plugins/plugin_model_inference.py | 57 +++++++++++-------- 3 files changed, 71 insertions(+), 35 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 6bc088e6..33f0ee12 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -199,7 +199,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None - stats: ImageStats = None + stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -541,7 +541,10 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - semantic_labels = np.swapaxes(semantic_labels, 0, 2) + total_dims = len(semantic_labels.shape) - 3 + semantic_labels = np.swapaxes( + semantic_labels, 0 + total_dims, 2 + total_dims + ) return InferenceResult( image_id=i + 1, @@ -582,8 +585,10 @@ def save_image( ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" + filetype = self.config.filetype else: original_filename = "_" + filetype = "" time = utils.get_date_time() @@ -594,7 +599,7 @@ def save_image( + original_filename + self.config.model_info.name + f"_{time}_" - + self.config.filetype + + filetype ) try: imwrite(file_path, image) @@ -619,22 +624,35 @@ def aniso_transform(self, image): else: return image - def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + def instance_seg( + self, to_instance, image_id=0, original_filename="layer", channel=None + ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method instance_labels = method.run_method(image=to_instance) + if channel is not None: + channel_id = f"_{channel}" + else: + channel_id = "" + + if self.config.filetype == "": + filetype = "" + else: + filetype = "_" + self.config.filetype + instance_filepath = ( self.config.results_path + "/" + f"Instance_seg_labels_{image_id}_" + original_filename + + channel_id + "_" + self.config.model_info.name - + f"_{utils.get_date_time()}_" - + self.config.filetype + + f"_{utils.get_date_time()}" + + filetype ) imwrite(instance_filepath, instance_labels) @@ -699,13 +717,21 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, stats = self.get_instance_result(out, from_layer=True) + instance_labels_results = [] + stats_results = [] + + for channel in out: + instance_labels, stats = self.get_instance_result( + channel, from_layer=True + ) + instance_labels_results.append(instance_labels) + stats_results.append(stats) return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels, + instance_labels=instance_labels_results, from_layer=True, - stats=stats, + stats=stats_results, ) # @thread_worker(connect={"errored": self.raise_error}) @@ -762,7 +788,7 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=dims, + input_img_size=[dims, dims, dims], device=self.config.device, num_classes=self.config.model_info.num_classes, ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 63a91b10..dffa3b44 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,8 @@ def __init__( ) def forward(self, x): - """Forward pass of the W-Net model.""" + """Forward ENCODER pass of the W-Net model. + Done this way to allow inference on the encoder only when called by sliding_window_inference.""" enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 44e70a76..522f91bb 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -733,37 +733,46 @@ def on_yield(self, result: InferenceResult): ) if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + for i, labels in enumerate(result.instance_labels): + # labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + from napari_cellseg3d.utils import LOGGER as log - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + log.debug(f"len stats : {len(result.stats)}") - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + for i, stats in enumerate(result.stats): + # stats = result.stats - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) + + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) + + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) except Exception as e: self.on_error(e) From 13faa9c9e3f22ea616fb6d7b8c5d853539312bb6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:48:12 +0200 Subject: [PATCH 377/577] Create test_models.py --- napari_cellseg3d/_tests/test_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 napari_cellseg3d/_tests/test_models.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py new file mode 100644 index 00000000..e2ba32e0 --- /dev/null +++ b/napari_cellseg3d/_tests/test_models.py @@ -0,0 +1,13 @@ +from napari_cellseg3d.config import MODEL_LIST + + +def test_model_list(): + for model_name in MODEL_LIST.keys(): + dims = 128 + test = MODEL_LIST[model_name]( + input_img_size=[dims, dims, dims], + in_channels=1, + out_channels=1, + dropout_prob=0.3, + ) + assert isinstance(test, MODEL_LIST[model_name]) From d16273e586263e8fd17bf2f142b7002aaa5c42d4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:42:56 +0200 Subject: [PATCH 378/577] Run full suite of pre-commit hooks --- docs/res/guides/custom_model_template.rst | 2 -- napari_cellseg3d/code_models/model_instance_seg.py | 2 ++ napari_cellseg3d/code_models/model_workers.py | 5 ++--- napari_cellseg3d/code_models/models/model_SwinUNetR.py | 1 + napari_cellseg3d/code_models/models/model_WNet.py | 3 ++- napari_cellseg3d/code_models/models/wnet/crf.py | 8 +++----- napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py | 8 ++++---- napari_cellseg3d/config.py | 1 - napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - 9 files changed, 14 insertions(+), 17 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 9bad49b0..218795b1 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -9,5 +9,3 @@ To add a custom model, you will need a **.py** file with the following structure **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute :: - - diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 0c87a2df..cc7fac90 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -12,6 +12,8 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 33f0ee12..dce2f452 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,8 +1,8 @@ import platform +import typing as t from dataclasses import dataclass from math import ceil from pathlib import Path -import typing as t import numpy as np import torch @@ -40,10 +40,9 @@ ) from monai.utils import set_determinism +# from napari.qt.threading import thread_worker # threads from napari.qt.threading import GeneratorWorker - -# from napari.qt.threading import thread_worker from napari.qt.threading import WorkerBaseSignals # Qt diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index f38409b8..05819e22 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,5 @@ from monai.networks.nets import SwinUNETR + from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index dffa3b44..750b8bdb 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -22,7 +22,8 @@ def __init__( def forward(self, x): """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference.""" + Done this way to allow inference on the encoder only when called by sliding_window_inference. + """ enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index ca11fba2..2ac0875d 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,11 +12,9 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - unary_from_softmax, - create_pairwise_gaussian, - create_pairwise_bilateral, -) +from pydensecrf.utils import create_pairwise_bilateral +from pydensecrf.utils import create_pairwise_gaussian +from pydensecrf.utils import unary_from_softmax __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 6a625355..4e84579f 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -1,15 +1,15 @@ """ Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. -The implementation was adapted and approximated to reduce computational and memory cost. +The implementation was adapted and approximated to reduce computational and memory cost. This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. """ import math + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -import numpy as np from scipy.stats import norm __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" @@ -56,7 +56,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): # self.distances, self.indexes = self.get_distances() """ - + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 4eaddb93..43f961f4 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -15,7 +15,6 @@ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ from napari_cellseg3d.code_models.models.model_WNet import WNet_ - from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..48249a94 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage From 858c1e976c252ff9b255c39fa469a1dc8f628c84 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:27:18 +0200 Subject: [PATCH 379/577] Patch for tests action + style --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/code_models/model_instance_seg.py | 6 ++++-- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 + napari_cellseg3d/utils.py | 1 + 5 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ea0a1e46..88a67ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -16,6 +16,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet workflow_dispatch: jobs: diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index cc7fac90..2f10aa1f 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -12,14 +12,16 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 750b8bdb..4a9ff70d 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -11,7 +11,7 @@ def __init__( out_channels=1, num_classes=2, device="cpu", - **kwargs + **kwargs, ): super().__init__( device=device, diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 48249a94..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 5683c541..9fbe6d7a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,6 +2,7 @@ import warnings from datetime import datetime from pathlib import Path + import numpy as np from monai.transforms import Zoom from skimage import io From c65e4f5530e978645197c27b618f6c29336c731f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 16:03:29 +0200 Subject: [PATCH 380/577] Add softNCuts basic test --- napari_cellseg3d/_tests/test_models.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index e2ba32e0..9280b230 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,3 +1,6 @@ +import torch + +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST @@ -11,3 +14,20 @@ def test_model_list(): dropout_prob=0.3, ) assert isinstance(test, MODEL_LIST[model_name]) + + +def test_soft_ncuts_loss(): + dims = 8 + labels = torch.rand([1, 1, dims, dims, dims]) + + loss = SoftNCutsLoss( + data_shape=[dims, dims, dims], + device="cpu", + o_i=4, + o_x=4, + radius=2, + ) + + res = loss.forward(labels, labels) + assert isinstance(res, torch.Tensor) + # assert res > 0 From d74576ef5f8edbaf1004c37ce8012df97eee34f7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 09:41:15 +0200 Subject: [PATCH 381/577] Added crf Co-Authored-By: Nevexios <72894299+nevexios@users.noreply.github.com> --- napari_cellseg3d/code_models/crf.py | 122 ++++++++++++++++++++++++++++ pyproject.toml | 3 + 2 files changed, 125 insertions(+) create mode 100644 napari_cellseg3d/code_models/crf.py diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py new file mode 100644 index 00000000..13f489c7 --- /dev/null +++ b/napari_cellseg3d/code_models/crf.py @@ -0,0 +1,122 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +from warnings import warn + +import numpy as np + +try: + import pydensecrf.densecrf as dcrf + from pydensecrf.utils import create_pairwise_bilateral + from pydensecrf.utils import create_pairwise_gaussian + from pydensecrf.utils import unary_from_softmax + + CRF_INSTALLED = True +except ImportError: + warn( + "pydensecrf not installed, CRF post-processing will not be available. " + "Please install by running pip install cellseg3d[crf]" + ) + CRF_INSTALLED = False + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/pyproject.toml b/pyproject.toml index d2a2adbb..d9a46ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ profile = "black" line_length = 79 [project.optional-dependencies] +crf = [ +"git+https://github.com/lucasb-eyer/pydensecrf.git", +] dev = [ "isort", "black", From 56f50e0262a37ef095e7f8875b2aaedaf8784fb6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 10:08:46 +0200 Subject: [PATCH 382/577] More pre-commit checks --- .pre-commit-config.yaml | 10 +-- napari_cellseg3d/_tests/fixtures.py | 6 +- napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- napari_cellseg3d/_tests/test_utils.py | 25 ++++--- .../_tests/test_weight_download.py | 6 +- napari_cellseg3d/code_models/crf.py | 11 +-- .../code_models/model_framework.py | 10 ++- .../code_models/model_instance_seg.py | 13 ++-- napari_cellseg3d/code_models/model_workers.py | 11 +-- .../code_models/models/wnet/crf.py | 8 ++- napari_cellseg3d/code_plugins/plugin_base.py | 3 +- .../code_plugins/plugin_convert.py | 16 ++--- napari_cellseg3d/code_plugins/plugin_crop.py | 5 +- .../code_plugins/plugin_model_inference.py | 17 +++-- .../code_plugins/plugin_model_training.py | 8 +-- .../code_plugins/plugin_review.py | 13 ++-- .../code_plugins/plugin_review_dock.py | 5 +- .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 5 +- .../dev_scripts/artefact_labeling.py | 3 +- napari_cellseg3d/dev_scripts/convert.py | 3 +- .../dev_scripts/correct_labels.py | 3 +- napari_cellseg3d/interface.py | 67 +++++++++---------- napari_cellseg3d/utils.py | 59 ++++++++-------- pyproject.toml | 7 +- 25 files changed, 166 insertions(+), 158 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..61ecaae5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index bd6b0ac7..b3044799 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,7 +1,7 @@ -import warnings - from qtpy.QtWidgets import QTextEdit +from napari_cellseg3d.utils import LOGGER as logger + class LogFixture(QTextEdit): """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing""" @@ -13,7 +13,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - warnings.warn(warning) + logger.warning(warning) def error(self, e): raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 7908e8b4..584be4d7 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) def test_utils_plugin(make_napari_viewer): diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index dc57b940..f2a9d32c 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,8 +1,7 @@ import os -import warnings +from functools import partial import numpy as np -import pytest import torch from napari_cellseg3d import utils @@ -33,6 +32,10 @@ def test_fill_list_in_between(): assert utils.fill_list_in_between(list, 2, "") == res + fill = partial(utils.fill_list_in_between, n=2, fill_value="") + + assert fill(list) == res + def test_align_array_sizes(): im = np.zeros((128, 512, 256)) @@ -79,15 +82,15 @@ def test_get_padding_dim(): tensor = torch.randn(2000, 30, 40) size = tensor.size() - warn = warnings.warn( - "Warning : a very large dimension for automatic padding has been computed.\n" - "Ensure your images are of an appropriate size and/or that you have enough memory." - "The padding value is currently 2048." - ) - - pad = utils.get_padding_dim(size) - - pytest.warns(warn, (lambda: utils.get_padding_dim(size))) + # warn = logger.warning( + # "Warning : a very large dimension for automatic padding has been computed.\n" + # "Ensure your images are of an appropriate size and/or that you have enough memory." + # "The padding value is currently 2048." + # ) + # + # pad = utils.get_padding_dim(size) + # + # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) assert pad == [2048, 32, 64] diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 1bcb40d7..b9d4abe5 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,5 +1,7 @@ -from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR -from napari_cellseg3d.code_models.model_workers import WeightsDownloader +from napari_cellseg3d.code_models.model_workers import ( + PRETRAINED_WEIGHTS_DIR, + WeightsDownloader, +) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 13f489c7..fc1e0b90 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -16,15 +16,18 @@ try: import pydensecrf.densecrf as dcrf - from pydensecrf.utils import create_pairwise_bilateral - from pydensecrf.utils import create_pairwise_gaussian - from pydensecrf.utils import unary_from_softmax + from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, + ) CRF_INSTALLED = True except ImportError: warn( "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]" + "Please install by running pip install cellseg3d[crf]", + stacklevel=1, ) CRF_INSTALLED = False diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index d541b486..37fc6a49 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import napari @@ -12,7 +11,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -135,11 +133,11 @@ def save_log(self): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) else: - warnings.warn(f"No logger defined : Log is {self.log}") + logger.warning(f"No logger defined : Log is {self.log}") def save_log_to_path(self, path): """Saves the worker log to a specific path. Cannot be used with connect. @@ -161,7 +159,7 @@ def save_log_to_path(self, path): f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) @@ -170,7 +168,7 @@ def display_status_report(self): (usually when starting a worker)""" # if self.container_report is None or self.log is None: - # warnings.warn( + # logger.warning( # "Status report widget has been closed. Trying to re-instantiate..." # ) # self.container_report = QWidget() diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index 2f10aa1f..d551920d 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,11 +1,11 @@ from dataclasses import dataclass +from functools import partial from typing import List import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -14,9 +14,8 @@ from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area @@ -399,8 +398,10 @@ def sphericity(region): volume = [region.area for region in properties] - def fill(lst, n=len(properties) - 1): - return fill_list_in_between(lst, n, "") + # def fill(lst, n=len(properties) - 1): + # return fill_list_in_between(lst, n, "") + + fill = partial(fill_list_in_between, n=len(properties) - 1, fill_value="") if len(volume_image.flatten()) != 0: ratio = fill([np.sum(volume) / len(volume_image.flatten())]) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index dce2f452..65b1c80a 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -42,8 +42,7 @@ # from napari.qt.threading import thread_worker # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -51,8 +50,12 @@ from tqdm import tqdm # local -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index 2ac0875d..004db3a1 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,9 +12,11 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import create_pairwise_bilateral -from pydensecrf.utils import create_pairwise_gaussian -from pydensecrf.utils import unary_from_softmax +from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, +) __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 2cb3581b..e7b97e01 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from pathlib import Path @@ -404,7 +403,7 @@ def load_dataset_paths(self): file_paths = sorted(Path(directory).glob("*" + filetype)) if len(file_paths) == 0: - warnings.warn( + logger.warning( f"The folder does not contain any compatible {filetype} files.\n" f"Please check the validity of the folder and images." ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index a847ebf7..0bff4cae 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,18 +1,18 @@ -import warnings from pathlib import Path import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -84,7 +84,7 @@ def show_result(viewer, layer, image, name): logger.debug("Added resulting label layer") viewer.add_labels(image, name=name) else: - warnings.warn( + logger.warning( f"Results not shown, unsupported layer type {type(layer)}" ) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 1647e858..d82df475 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import napari @@ -245,7 +244,7 @@ def _start(self): # maybe use singletons or make docked widgets attributes that are hidden upon opening if not self._check_ready(): - warnings.warn("Please select at least one valid layer !") + logger.warning("Please select at least one valid layer !") return # self._viewer.window.remove_dock_widget(self.parent()) # no need to close utils ? @@ -329,7 +328,7 @@ def add_isotropic_layer( self, layer, colormap="inferno", - contrast_lim=[200, 1000], # TODO generalize ? + contrast_lim=(200, 1000), # TODO generalize ? opacity=0.7, visible=True, ): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 522f91bb..eab16c8b 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,4 +1,3 @@ -import warnings from functools import partial import napari @@ -9,10 +8,16 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) + +logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -539,7 +544,7 @@ def start(self): if not self._check_results_path(save_path): msg = f"ERROR: please set valid results path. Current path is {save_path}" self.log.print_and_log(msg) - warnings.warn(msg) + logger.warning(msg) else: if self.results_path is None: self.results_path = save_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 132c9531..88991f43 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,5 +1,4 @@ import shutil -import warnings from functools import partial from pathlib import Path @@ -418,8 +417,7 @@ def check_ready(self): if self.images_filepaths != [] and self.labels_filepaths != []: return True else: - warnings.formatwarning = utils.format_Warning - warnings.warn("Image and label paths are not correctly set") + logger.warning("Image and label paths are not correctly set") return False def _build(self): @@ -787,7 +785,7 @@ def start(self): if not self.check_ready(): # issues a warning if not ready err = "Aborting, please set all required paths" self.log.print_and_log(err) - warnings.warn(err) + logger.warning(err) return if self.worker is not None: @@ -1043,7 +1041,7 @@ def _make_csv(self): size_column = range(1, self.worker_config.max_epochs + 1) if len(self.loss_values) == 0 or self.loss_values is None: - warnings.warn("No loss values to add to csv !") + logger.warning("No loss values to add to csv !") return self.df = pd.DataFrame( diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index e3e05f6c..235595e4 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -20,7 +19,6 @@ from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -180,10 +178,11 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: - warnings.warn( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None: + if cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -237,7 +236,7 @@ def run_review(self): self._reset() previous_viewer.close() except ValueError as e: - warnings.warn( + logger.warning( f"An exception occurred : {e}. Please ensure you have entered all required parameters." ) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index c09c376f..8753a642 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,4 +1,3 @@ -import warnings from datetime import datetime, timedelta from pathlib import Path @@ -16,7 +15,7 @@ GUI_MINIMUM_HEIGHT = 300 TIMER_FORMAT = "%H:%M:%S" - +logger = utils.LOGGER """ plugin_dock.py ==================================== @@ -261,7 +260,7 @@ def update_dm(self, slice_num): def button_func(self): # updates csv every time you press button... if self.viewer.dims.ndisplay != 2: # TODO test if undefined behaviour or if okay - warnings.warn("Please switch back to 2D mode !") + logger.warning("Please switch back to 2D mode !") return self.update_time_csv() diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 45c0c119..462ee450 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,9 +2,7 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 43f961f4..2b38eb29 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -1,5 +1,4 @@ import datetime -import warnings from dataclasses import dataclass from pathlib import Path from typing import List, Optional @@ -84,9 +83,9 @@ def get_model(self): return MODEL_LIST[self.name] except KeyError as e: msg = f"Model {self.name} is not defined" - warnings.warn(msg) logger.warning(msg) - raise KeyError(e) + logger.warning(msg) + raise KeyError from e @staticmethod def get_model_name_list(): diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..b4712aec 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index 479a07dd..641de627 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,8 +2,7 @@ import os import numpy as np -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..2ab60332 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 9a100dc2..36fc9aab 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,5 +1,4 @@ import threading -import warnings from functools import partial from typing import List, Optional @@ -9,32 +8,30 @@ from qtpy import QtCore # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -288,10 +285,10 @@ def print_and_log(self, text, printing=True): self.lock.release() def warn(self, warning): - """Show warnings.warn from another thread""" + """Show logger.warning from another thread""" self.lock.acquire() try: - warnings.warn(warning) + logger.warning(warning) finally: self.lock.release() @@ -536,7 +533,7 @@ def _build_container(self): ) def _warn_outside_bounds(self, default): - warnings.warn( + logger.warning( f"Default value {default} was outside of the ({self.minimum()}:{self.maximum()}) range" ) @@ -581,7 +578,7 @@ def slider_value(self): try: return self.value() / self._divide_factor except ZeroDivisionError as e: - raise ZeroDivisionError( + raise ZeroDivisionError from ( f"Divide factor cannot be 0 for Slider : {e}" ) @@ -791,8 +788,8 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: - warnings.warn("Please select a valid layer !") - return None + logger.warning("Please select a valid layer !") + return return self._viewer.layers[self.layer_name()].data @@ -1188,7 +1185,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), filetype: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1212,7 +1209,7 @@ def open_file_dialog( def open_folder_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), ): default_path = utils.parse_default_path(possible_paths) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 9fbe6d7a..f3fc09ba 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,5 +1,4 @@ import logging -import warnings from datetime import datetime from pathlib import Path @@ -234,7 +233,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # warnings.warn( + # logger.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -242,7 +241,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 if pad >= 256: - warnings.warn( + LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." f"The padding value is currently {pad}." @@ -342,14 +341,14 @@ def annotation_to_input(label_ermito): # pass -def fill_list_in_between(lst, n, elem): +def fill_list_in_between(lst, n, fill_value): """Fills a list with n * elem between each member of list. Example with list = [1,2,3], n=2, elem='&' : returns [1, &, &,2,&,&,3,&,&] Args: lst: list to fill n: number of elements to add - elem: added n times after each element of list + fill_value: added n times after each element of list Returns : Filled list @@ -358,13 +357,13 @@ def fill_list_in_between(lst, n, elem): for i in range(len(lst)): temp_list = [lst[i]] while len(temp_list) < n + 1: - temp_list.append(elem) + temp_list.append(fill_value) if i < len(lst) - 1: new_list += temp_list else: new_list.append(lst[i]) for _j in range(n): - new_list.append(elem) + new_list.append(fill_value) return new_list return None @@ -535,26 +534,26 @@ def select_train_data(dataframe, ori_imgs, label_imgs, ori_filenames): return np.array(train_ori_imgs), np.array(train_label_imgs) -def format_Warning(message, category, filename, lineno, line=""): - """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` - - Args: - message: warning message - category: which type of warning has been raised - filename: file - lineno: line number - line: unused - - Returns: format - - """ - return ( - str(filename) - + ":" - + str(lineno) - + ": " - + category.__name__ - + ": " - + str(message) - + "\n" - ) +# def format_Warning(message, category, filename, lineno, line=""): +# """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` +# +# Args: +# message: warning message +# category: which type of warning has been raised +# filename: file +# lineno: line number +# line: unused +# +# Returns: format +# +# """ +# return ( +# str(filename) +# + ":" +# + str(lineno) +# + ": " +# + category.__name__ +# + ": " +# + str(message) +# + "\n" +# ) diff --git a/pyproject.toml b/pyproject.toml index d9a46ccf..8e7187f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,12 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). +select = [ + "E", "F", "W", + "I", + "B", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) ignore = ["E501", "E741"] [tool.black] From 543b3f81453d6d2927ea9c3dfeea4ccef5b7aea8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:29:42 +0200 Subject: [PATCH 383/577] Functional CRF --- napari_cellseg3d/_tests/test_models.py | 39 +++ napari_cellseg3d/code_models/crf.py | 98 ++++++- .../code_models/model_instance_seg.py | 6 +- napari_cellseg3d/code_models/model_workers.py | 48 +++- napari_cellseg3d/code_plugins/plugin_base.py | 23 +- .../code_plugins/plugin_convert.py | 194 ++++--------- napari_cellseg3d/code_plugins/plugin_crf.py | 262 ++++++++++++++++++ napari_cellseg3d/code_plugins/plugin_crop.py | 7 +- .../code_plugins/plugin_model_inference.py | 32 ++- .../code_plugins/plugin_utilities.py | 15 +- napari_cellseg3d/config.py | 16 ++ napari_cellseg3d/interface.py | 19 +- napari_cellseg3d/utils.py | 81 +++++- 13 files changed, 671 insertions(+), 169 deletions(-) create mode 100644 napari_cellseg3d/code_plugins/plugin_crf.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 9280b230..1fc15872 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,9 +1,18 @@ +import numpy as np import torch +from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +def test_correct_shape_for_crf(): + test = np.random.rand(1, 1, 8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + test = np.random.rand(8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + + def test_model_list(): for model_name in MODEL_LIST.keys(): dims = 128 @@ -31,3 +40,33 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) # assert res > 0 + + +def test_crf(qtbot): + dims = 8 + mock_image = np.random.rand(1, dims, dims, dims) + mock_label = np.random.rand(2, dims, dims, dims) + + crf = CRFWorker(mock_image, mock_label) + + def on_yield(result): + assert isinstance(result, np.ndarray) + assert result.shape[-3:] == mock_label.shape[-3:] + + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) + + mock_image = mock_image[0] + mock_label = mock_label[0] + + crf = CRFWorker(mock_image, mock_label) + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index fc1e0b90..a0146a5e 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -9,11 +9,8 @@ Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. """ - from warnings import warn -import numpy as np - try: import pydensecrf.densecrf as dcrf from pydensecrf.utils import ( @@ -31,6 +28,12 @@ ) CRF_INSTALLED = False + +import numpy as np +from napari.qt.threading import GeneratorWorker + +from napari_cellseg3d.config import CRFConfig + __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ "Yves Paychère", @@ -49,6 +52,16 @@ ] +def correct_shape_for_crf(image): + if len(image.shape) == 4: + return image + if len(image.shape) > 4: + image = np.squeeze(image, axis=0) + if len(image.shape) < 4: + image = np.expand_dims(image, axis=0) + return correct_shape_for_crf(image) + + def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): """CRF post-processing step for the W-Net, applied to a batch of images. @@ -62,6 +75,8 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): Returns: np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. """ + if not CRF_INSTALLED: + return None return np.stack( [ @@ -83,10 +98,16 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + w1 (float): weight of the appearance/bilateral kernel. + w2 (float): weight of the smoothness/gaussian kernel. Returns: np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. """ + + if not CRF_INSTALLED: + return None + d = dcrf.DenseCRF( image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] ) @@ -123,3 +144,74 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): return np.array(Q).reshape( (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) ) + + +def crf_with_config(image, prob, config: CRFConfig = None): + if config is None: + config = CRFConfig() + if image.shape[-3:] != prob.shape[-3:]: + raise ValueError( + f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" + f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" + ) + + image = correct_shape_for_crf(image) + + return crf( + image, + prob, + config.sa, + config.sb, + config.sg, + config.w1, + config.w2, + config.n_iters, + ) + + +class CRFWorker(GeneratorWorker): + """Worker for the CRF post-processing step for the W-Net.""" + + def __init__( + self, + images_list, + labels_list, + config: CRFConfig = None, + log=None, + ): + super().__init__(self._run_crf_job) + + self.images = images_list + self.labels = labels_list + if config is None: + self.config = CRFConfig() + else: + self.config = config + self.log = log + + # TODO(cyril) : add progress bar into log ? or do it in inference + def _run_crf_job(self): + """Runs the CRF post-processing step for the W-Net.""" + if not CRF_INSTALLED: + raise ImportError("pydensecrf is not installed.") + + for image, labels in zip(self.images, self.labels): + if len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + + if len(labels.shape) == 3: + labels = np.expand_dims(labels, axis=0) + + if image.shape[-3:] != labels.shape[-3:]: + raise ValueError("Image and labels must have the same shape.") + + yield crf( + image, + labels, + self.config.sa, + self.config.sb, + self.config.sg, + self.config.w1, + self.config.w2, + n_iter=self.config.n_iters, + ) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d551920d..d1a03eec 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -65,7 +65,7 @@ def __init__( 1, divide_factor=100, text_label="", - parent=None, + parent=widget_parent, ), ) self.sliders.append(getattr(self, widget)) @@ -76,7 +76,9 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(text_label="", parent=None), + ui.DoubleIncrementCounter( + text_label="", parent=widget_parent + ), ) self.counters.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 65b1c80a..c7196db7 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -52,6 +52,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import crf_with_config from napari_cellseg3d.code_models.model_instance_seg import ( ImageStats, volume_stats, @@ -201,6 +202,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None + crf_results: np.array = None stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -528,7 +530,8 @@ def create_inference_result( self, semantic_labels, instance_labels, - from_layer: bool, + crf_results=None, + from_layer: bool = False, original=None, stats=None, i=0, @@ -543,15 +546,19 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - total_dims = len(semantic_labels.shape) - 3 + extra_dims = len(semantic_labels.shape) - 3 semantic_labels = np.swapaxes( - semantic_labels, 0 + total_dims, 2 + total_dims + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims ) return InferenceResult( image_id=i + 1, original=original, instance_labels=instance_labels, + crf_results=crf_results, stats=stats, result=semantic_labels, model_name=self.config.model_info.name, @@ -584,6 +591,7 @@ def save_image( image, from_layer=False, i=0, + additional_info="", ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" @@ -597,7 +605,7 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"Prediction_{i+1}" + + f"{additional_info}_Prediction_{i+1}" + original_filename + self.config.model_info.name + f"_{time}_" @@ -679,6 +687,15 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.save_image(out, i=i) instance_labels, stats = self.get_instance_result(out, i=i) + if self.config.use_crf: + try: + crf_results = self.run_crf(inputs, out, image_id=i) + + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + crf_results = None + else: + crf_results = None original = np.array(inf_data["image"]).astype(np.float32) @@ -687,12 +704,29 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): return self.create_inference_result( out, instance_labels, + crf_results, from_layer=False, original=original, stats=stats, i=i, ) + def run_crf(self, image, labels, image_id=0): + self.log(f"IMAGE SHAPE : {image.shape}") + self.log(f"LABEL SHAPE : {labels.shape}") + + try: + crf_results = crf_with_config( + image, labels, config=self.config.crf_config + ) + self.save_image( + crf_results, i=image_id, additional_info="CRF", from_layer=True + ) + return crf_results + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + return None + def stats_csv(self, instance_labels): if self.config.compute_stats: stats = volume_stats(instance_labels) @@ -729,9 +763,15 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels_results.append(instance_labels) stats_results.append(stats) + if self.config.use_crf: + crf_results = self.run_crf(image, out) + else: + crf_results = None + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels_results, + crf_results=crf_results, from_layer=True, stats=stats_results, ) diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index e7b97e01..26da7a42 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -46,15 +46,15 @@ def __init__( self.image_path = None """str: path to image folder""" - self.show_image_io = loads_images + self._show_image_io = loads_images self.label_path = None """str: path to label folder""" - self.show_label_io = loads_labels + self._show_label_io = loads_labels self.results_path = None """str: path to results folder""" - self.show_results_io = has_results + self._show_results_io = has_results self._default_path = [self.image_path, self.label_path] @@ -117,7 +117,6 @@ def show_menu(_, event): def _build_io_panel(self): self.io_panel = ui.GroupedWidget("Data") self.save_label = ui.make_label("Save location :", parent=self) - # self.io_panel.setToolTip("IO Panel") ui.add_widgets( @@ -139,25 +138,25 @@ def _build_io_panel(self): return self.io_panel def _remove_unused(self): - if not self.show_label_io: + if not self._show_label_io: self.labels_filewidget = None self.label_layer_loader = None - if not self.show_image_io: + if not self._show_image_io: self.image_layer_loader = None self.image_filewidget = None - if not self.show_results_io: + if not self._show_results_io: self.results_filewidget = None def _set_io_visibility(self): ################## # Show when layer is selected - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_layer_loader, self.layer_choice) else: self._hide_io_element(self.image_layer_loader) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.label_layer_loader, self.layer_choice) else: self._hide_io_element(self.label_layer_loader) @@ -167,15 +166,15 @@ def _set_io_visibility(self): f = self.folder_choice self._show_io_element(self.filetype_choice, f) - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_filewidget, f) else: self._hide_io_element(self.image_filewidget) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.labels_filewidget, f) else: self._hide_io_element(self.labels_filewidget) - if not self.show_results_io: + if not self._show_results_io: self._hide_io_element(self.results_filewidget) self.folder_choice.toggle() diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 0bff4cae..f7b476d0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -3,7 +3,7 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -15,80 +15,12 @@ ) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -# TODO break down into multiple mini-widgets -# TODO create parent class for utils modules to avoid duplicates - -MAX_W = 200 -MAX_H = 1000 +MAX_W = ui.UTILS_MAX_WIDTH +MAX_H = ui.UTILS_MAX_HEIGHT logger = utils.LOGGER -def save_folder(results_path, folder_name, images, image_paths): - """ - Saves a list of images in a folder - - Args: - results_path: Path to the folder containing results - folder_name: Name of the folder containing results - images: List of images to save - image_paths: list of filenames of images - """ - results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False, parents=True) - - for file, image in zip(image_paths, images): - path = results_folder / Path(file).name - - imwrite( - path, - image, - ) - logger.info(f"Saved processed folder as : {results_folder}") - - -def save_layer(results_path, image_name, image): - """ - Saves an image layer at the specified path - - Args: - results_path: path to folder containing result - image_name: image name for saving - image: data array containing image - - Returns: - - """ - path = str(results_path / Path(image_name)) # TODO flexible filetype - logger.info(f"Saved as : {path}") - imwrite(path, image) - - -def show_result(viewer, layer, image, name): - """ - Adds layers to a viewer to show result to user - - Args: - viewer: viewer to add layer in - layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer - image: the data array containing the image - name: name of the added layer - - Returns: - - """ - if isinstance(layer, napari.layers.Image): - logger.debug("Added resulting image layer") - viewer.add_image(image, name=name) - elif isinstance(layer, napari.layers.Labels): - logger.debug("Added resulting label layer") - viewer.add_labels(image, name=name) - else: - logger.warning( - f"Results not shown, unsupported layer type {type(layer)}" - ) - - class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -154,31 +86,30 @@ def _start(self): data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) - save_layer( + utils.save_layer( self.results_path, f"isotropic_{layer.name}_{utils.get_date_time()}.tif", isotropic_image, ) - show_result( + utils.show_result( self._viewer, layer, isotropic_image, f"isotropic_{layer.name}", ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -254,27 +185,26 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"cleared_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -336,12 +266,12 @@ def _start(self): data = np.array(layer.data) semantic = to_semantic(data) - save_layer( + utils.save_layer( self.results_path, f"semantic_{layer.name}_{utils.get_date_time()}.tif", semantic, ) - show_result( + utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) elif self.folder_choice.isChecked(): @@ -350,7 +280,7 @@ def _start(self): to_semantic(file, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"semantic_results_{utils.get_date_time()}", images, @@ -421,7 +351,7 @@ def _start(self): data = np.array(layer.data) instance = self.instance_widgets.run_method(data) - save_layer( + utils.save_layer( self.results_path, f"instance_{layer.name}_{utils.get_date_time()}.tif", instance, @@ -430,19 +360,18 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.instance_widgets.run_method(imread(file)) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -517,27 +446,26 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"threshold_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py new file mode 100644 index 00000000..3dbd47bb --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -0,0 +1,262 @@ +from functools import partial +from pathlib import Path + +import napari.layers +from qtpy.QtWidgets import QSizePolicy +from tqdm import tqdm + +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage +from napari_cellseg3d.utils import LOGGER as logger + + +# TODO add CRF on folder +class CRFParamsWidget(ui.GroupedWidget): + """Use this widget when adding the crf as part of another widget (rather than a standalone widget)""" + + def __init__(self, parent=None): + super().__init__(title="CRF parameters", parent=parent) + ####### + # CRF params # + self.sa_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Alpha std" + ) + self.sb_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Beta std" + ) + self.sg_choice = ui.DoubleIncrementCounter( + default=1, parent=self, text_label="Gamma std" + ) + self.w1_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Weight appearance" + ) + self.w2_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Weight smoothness" + ) + self.n_iter_choice = ui.IntIncrementCounter( + default=5, parent=self, text_label="Number of iterations" + ) + ####### + self._build() + self._set_tooltips() + + def _build(self): + ui.add_widgets( + self.layout, + [ + # self.sa_choice.label, + self.sa_choice, + # self.sb_choice.label, + self.sb_choice, + # self.sg_choice.label, + self.sg_choice, + # self.w1_choice.label, + self.w1_choice, + # self.w2_choice.label, + self.w2_choice, + # self.n_iter_choice.label, + self.n_iter_choice, + ], + ) + self.set_layout() + + def _set_tooltips(self): + self.sa_choice.setToolTip( + "SA : Standard deviation of the Gaussian kernel in the appearance term." + ) + self.sb_choice.setToolTip( + "SB : Standard deviation of the Gaussian kernel in the smoothness term." + ) + self.sg_choice.setToolTip( + "SG : Standard deviation of the Gaussian kernel in the gradient term." + ) + self.w1_choice.setToolTip( + "W1 : Weight of the appearance term in the CRF." + ) + self.w2_choice.setToolTip( + "W2 : Weight of the smoothness term in the CRF." + ) + self.n_iter_choice.setToolTip("Number of iterations of the CRF.") + + def make_config(self): + return config.CRFConfig( + sa=self.sa_choice.value(), + sb=self.sb_choice.value(), + sg=self.sg_choice.value(), + w1=self.w1_choice.value(), + w2=self.w2_choice.value(), + n_iters=self.n_iter_choice.value(), + ) + + +class CRFWidget(BasePluginSingleImage): + def __init__(self, viewer, parent=None): + """ + Create a widget for CRF post-processing. + Args: + viewer: napari viewer to display the widget + parent: parent widget. Defaults to None. + """ + super().__init__(viewer, parent) + self._viewer = viewer + + self.start_button = ui.Button("Start", self._start, parent=self) + self.crf_params_widget = CRFParamsWidget(parent=self) + self.io_panel = self._build_io_panel() + self.io_panel.setVisible(False) + + self.results_filewidget.setVisible(True) + self.label_layer_loader.setVisible(True) + self.label_layer_loader.set_layer_type( + napari.layers.Image + ) # to load all crf-compatible inputs, not int only + self.image_layer_loader.setVisible(True) + self.start_button.setVisible(True) + + self.result_layer = None + self.result_name = None + self.crf_results = [] + + self.results_path = Path.home() / Path("cellseg3d/crf") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._container = ui.ContainerWidget(parent=self, l=11, t=11, r=11) + self.layout = self._container.layout + + self._build() + + self.worker = None + self.log = None + + def _build(self): + self.setMinimumWidth(100) + ui.add_widgets( + self.layout, + [ + self.image_layer_loader, + self.label_layer_loader, + self.save_label, + self.results_filewidget, + ui.make_label(""), + self.crf_params_widget, + ui.make_label(""), + self.start_button, + ], + ) + # self.io_panel.setLayout(self.io_panel.layout) + self.setLayout(self.layout) + + ui.ScrollArea.make_scrollable( + self.layout, self, max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT] + ) + self._container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return self._container + + def make_config(self): + return self.crf_params_widget.make_config() + + def _check_ready(self): + if len(self.label_layer_loader.layer_list) < 1: + logger.warning("No label layer loaded") + return False + if len(self.image_layer_loader.layer_list) < 1: + logger.warning("No image layer loaded") + return False + + if len(self.label_layer_loader.layer_data().shape) < 3: + logger.warning("Label layer must be 3D") + return False + if len(self.image_layer_loader.layer_data().shape) < 3: + logger.warning("Image layer must be 3D") + return False + if ( + self.label_layer_loader.layer_data().shape[-3:] + != self.image_layer_loader.layer_data().shape[-3:] + ): + logger.warning("Image and label layers must have the same shape!") + return False + + return True + + def run_crf_on_batch(self, images_list: list, labels_list: list, log=None): + self.crf_results = [] + for image, label in zip(images_list, labels_list): + tqdm( + unit="B", + total=len(images_list), + position=0, + file=log, + ) + result = crf_with_config(image, label, self.make_config()) + self.crf_results.append(result) + return self.crf_results + + def _prepare_worker(self, images_list: list, labels_list: list): + self.worker = CRFWorker( + images_list=images_list, + labels_list=labels_list, + config=self.make_config(), + ) + + self.worker.started.connect(self._on_start) + self.worker.yielded.connect(partial(self._on_yield)) + self.worker.errored.connect(partial(self._on_error)) + self.worker.finished.connect(self._on_finish) + + def _start(self): + if not self._check_ready(): + return + + self.result_layer = self.label_layer_loader.layer() + self.result_name = self.label_layer_loader.layer_name() + + self.results_path.mkdir(exist_ok=True, parents=True) + + image_list = [self.image_layer_loader.layer_data()] + labels_list = [self.label_layer_loader.layer_data()] + [logger.debug(f"Image shape: {image.shape}") for image in image_list] + [ + logger.debug(f"Label shape: {labels.shape}") + for labels in labels_list + ] + + self._prepare_worker(image_list, labels_list) + + if self.worker.is_running: # if worker is running, tries to stop + logger.info("Stop request, waiting for previous job to finish") + self.start_button.setText("Stopping...") + self.worker.quit() + else: # once worker is started, update buttons + self.start_button.setText("Running...") + logger.info("Starting CRF...") + self.worker.start() + + def _on_yield(self, result): + self.crf_results.append(result) + + utils.save_layer( + self.results_filewidget.text_field.text(), + str(self.result_name + "_crf.tif"), + result, + ) + self._viewer.add_image( + result, + name="crf_" + self.result_name, + ) + + def _on_start(self): + self.crf_results = [] + + def _on_finish(self): + self.worker = None + + def _on_error(self, error): + logger.error(error) + self.start_button.setText("Start") + self.worker.quit() + self.worker = None diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index d82df475..6e7f91f3 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -174,7 +174,12 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + ui.ScrollArea.make_scrollable( + layout, + self, + max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], + min_wh=[200, 200], + ) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index eab16c8b..472cccd8 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -16,6 +16,7 @@ InferenceResult, InferenceWorker, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget logger = utils.LOGGER @@ -195,9 +196,17 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) + self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( - "Run instance segmentation", func=self._toggle_display_instance + "Run instance segmentation", + func=self._toggle_display_instance, + parent=self, + ) + self.use_crf = ui.CheckBox( + "Use CRF post-processing", + func=self._toggle_display_crf, + parent=self, ) self.save_stats_to_csv_box = ui.CheckBox( @@ -309,6 +318,10 @@ def _toggle_display_thresh(self): self.thresholding_checkbox, self.thresholding_slider.container ) + def _toggle_display_crf(self): + """Shows the choices for CRF post-processing depending on whether :py:attr:`self.use_crf` is checked""" + ui.toggle_visibility(self.use_crf, self.crf_widgets) + def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection""" ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) @@ -426,6 +439,8 @@ def _build(self): self.thresholding_slider.container, # thresholding self.use_instance_choice, self.instance_widgets, + self.use_crf, + self.crf_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -437,6 +452,7 @@ def _build(self): self.anisotropy_wdgt.container.setVisible(False) self.thresholding_slider.container.setVisible(False) self.instance_widgets.setVisible(False) + self.crf_widgets.setVisible(False) self.save_stats_to_csv_box.setVisible(False) post_proc_group.setLayout(post_proc_layout) @@ -588,6 +604,8 @@ def start(self): compute_stats=self.save_stats_to_csv_box.isChecked(), post_process_config=self.post_process_config, sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), ) ##################### ##################### @@ -737,7 +755,10 @@ def on_yield(self, result: InferenceResult): opacity=0.8, ) - if result.instance_labels is not None: + if ( + len(result.instance_labels) > 0 + and self.worker_config.post_process_config.instance.enabled + ): for i, labels in enumerate(result.instance_labels): # labels = result.instance_labels method_name = ( @@ -779,5 +800,12 @@ def on_yield(self, result: InferenceResult): # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) + + if result.crf_results is not None: + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 462ee450..868dd279 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -13,6 +13,7 @@ ToInstanceUtils, ToSemanticUtils, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -22,6 +23,7 @@ "Convert to instance labels": ToInstanceUtils, "Convert to semantic labels": ToSemanticUtils, "Threshold": ThresholdUtils, + "CRF": CRFWidget, } @@ -30,7 +32,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() self._viewer = viewer - attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh"] + attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh", "crf"] self._create_utils_widgets(attr_names) # self.crop = Cropping(self._viewer) @@ -54,8 +56,15 @@ def __init__(self, viewer: "napari.viewer.Viewer"): def _build(self): layout = QVBoxLayout() ui.add_widgets(layout, self.utils_widgets) - layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) - layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) + ui.GroupedWidget.create_single_widget_group( + "Utilities", + widget=self.utils_choice, + layout=layout, + alignment=ui.BOTT_AL, + ) + + # layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) + # layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 2b38eb29..8a7c1565 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -137,6 +137,20 @@ class PostProcessConfig: instance: InstanceSegConfig = InstanceSegConfig() +@dataclass +class CRFConfig: + """ + Class to record params for CRF + """ + + sa: float = 10 + sb: float = 5 + sg: float = 1 + w1: float = 10 + w2: float = 5 + n_iters: int = 5 + + ################ # Inference configs @@ -196,6 +210,8 @@ class InferenceWorkerConfig: compute_stats: bool = False post_process_config: PostProcessConfig = PostProcessConfig() sliding_window_config: SlidingWindowConfig = SlidingWindowConfig() + use_crf: bool = False + crf_config: CRFConfig = CRFConfig() images_filepaths: str = None layer: napari.layers.Layer = None diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 36fc9aab..55e5abb3 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -57,6 +57,8 @@ """Alias for Qt.AlignmentFlag.AlignAbsolute, to use in addWidget""" BOTT_AL = Qt.AlignmentFlag.AlignBottom """Alias for Qt.AlignmentFlag.AlignBottom, to use in addWidget""" +TOP_AL = Qt.AlignmentFlag.AlignTop +"""Alias for Qt.AlignmentFlag.AlignTop, to use in addWidget""" ############### # colors dark_red = "#72071d" # crimson red @@ -65,6 +67,9 @@ napari_param_grey = "#414851" # napari parameters menu color (lighter gray) napari_param_darkgrey = "#202228" # napari default LineEdit color ############### +# dimensions for utils ScrollArea +UTILS_MAX_WIDTH = 300 +UTILS_MAX_HEIGHT = 500 logger = utils.LOGGER @@ -791,7 +796,7 @@ def layer_data(self): logger.warning("Please select a valid layer !") return - return self._viewer.layers[self.layer_name()].data + return self.layer().data class FilePathWidget(QWidget): # TODO include load as folder @@ -1277,12 +1282,20 @@ def set_layout(self): @classmethod def create_single_widget_group( - cls, title, widget, layout, l=7, t=20, r=7, b=11 + cls, + title, + widget, + layout, + l=7, + t=20, + r=7, + b=11, + alignment=LEFT_AL, ): group = cls(title, l, t, r, b) group.layout.addWidget(widget) group.setLayout(group.layout) - layout.addWidget(group) + layout.addWidget(group, alignment=alignment) def add_widgets(layout, widgets, alignment=LEFT_AL): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index f3fc09ba..e7eaf95a 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,11 +2,12 @@ from datetime import datetime from pathlib import Path +import napari import numpy as np from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian -from tifffile import imread as tfl_imread +from tifffile import imread, imwrite LOGGER = logging.getLogger(__name__) ############### @@ -21,6 +22,76 @@ """ +#################### +# viewer utils +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False, parents=True) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + LOGGER.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + LOGGER.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + LOGGER.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + LOGGER.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + LOGGER.warning( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + +#################### + + class Singleton(type): """ Singleton class that can only be instantiated once at a time, @@ -44,7 +115,7 @@ def __call__(cls, *args, **kwargs): # if filename == "tif": # return True # def read(self, data, **kwargs): -# return tfl_imread(data) +# return imread(data) # # def get_data(self, data): # return data, {} @@ -233,7 +304,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # logger.warning( + # LOGGER.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -470,9 +541,7 @@ def load_images( ) # images_original = dask_imread(filename_pattern_original) else: - images_original = tfl_imread( - filename_pattern_original - ) # tifffile imread + images_original = imread(filename_pattern_original) # tifffile imread return images_original From 79859b069aa531b71bca42d541cd9c9a45d8a928 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:37:33 +0200 Subject: [PATCH 384/577] Fix erroneous test comment, added toggle for crf - Warn if crf not installed - Fix test --- napari_cellseg3d/_tests/test_utils.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 22 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index f2a9d32c..0b28183d 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -88,7 +88,7 @@ def test_get_padding_dim(): # "The padding value is currently 2048." # ) # - # pad = utils.get_padding_dim(size) + pad = utils.get_padding_dim(size) # # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 3dbd47bb..cbdacf3a 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -7,7 +7,11 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_models.crf import ( + CRF_INSTALLED, + CRFWorker, + crf_with_config, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.utils import LOGGER as logger @@ -43,6 +47,17 @@ def __init__(self, parent=None): self._set_tooltips() def _build(self): + if not CRF_INSTALLED: + ui.add_widgets( + self.layout, + [ + ui.make_label( + "ERROR: CRF not installed.\nPlease refer to the documentation to install it." + ), + ], + ) + self.set_layout() + return ui.add_widgets( self.layout, [ @@ -113,7 +128,10 @@ def __init__(self, viewer, parent=None): napari.layers.Image ) # to load all crf-compatible inputs, not int only self.image_layer_loader.setVisible(True) - self.start_button.setVisible(True) + if CRF_INSTALLED: + self.start_button.setVisible(True) + else: + self.start_button.setVisible(False) self.result_layer = None self.result_name = None From 0865b255b94188da1c887a37e013d30af3b76d18 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:56:08 +0200 Subject: [PATCH 385/577] Specify missing test deps --- pyproject.toml | 3 ++- tox.ini | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e7187f5..5648ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ -"git+https://github.com/lucasb-eyer/pydensecrf.git", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,4 +81,5 @@ test = [ "coverage", "tox", "twine", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 87338cd8..a3eef589 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + "git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From a887103869f6c5b34f50e75f91888b12fcbd4d71 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:02:31 +0200 Subject: [PATCH 386/577] Trying to fix deps on Git --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5648ab40..73fc862c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] diff --git a/tox.ini b/tox.ini index a3eef589..65a49bdd 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - "git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From a1742645d9f90c8537048fac1bcde62d6ae4b29f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:04:33 +0200 Subject: [PATCH 387/577] Removed master link to pydensecrf --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73fc862c..8d9d6bf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 65a49bdd..6f71b9db 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 00b772291bbfd70a5dc6f3af70781a930d16efdb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:07:23 +0200 Subject: [PATCH 388/577] Use commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d9d6bf4..0cc237e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] diff --git a/tox.ini b/tox.ini index 6f71b9db..5e0777f3 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From a8435f3bec9f00644a1b733047ab9c0e9a04a1b1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:09:27 +0200 Subject: [PATCH 389/577] Removed commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0cc237e5..09ed8585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] diff --git a/tox.ini b/tox.ini index 5e0777f3..3d7df5d0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 99c66f5fdb05e55aa67ecbc2b631e84bad940aa0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:11:27 +0200 Subject: [PATCH 390/577] Removed master --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 09ed8585..db39904b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] diff --git a/tox.ini b/tox.ini index 3d7df5d0..fd92727c 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From e326e95a35bc64ad4ef94ec04845060f55f2fa57 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:17:16 +0200 Subject: [PATCH 391/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index fd92727c..0a7c07f0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" + pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From 1691e6ee3ad26544dc6973ad64b45670bbb00a3b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 09:06:23 +0200 Subject: [PATCH 392/577] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db39904b..d223072a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] From e4b37bc1ccf03cf9fd68ed3f2774f4899acdee11 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 17:41:05 +0200 Subject: [PATCH 393/577] Fixes and improvements - More CRF info - Added error handling to scheduler rate - Added ETA to training - Updated padding warning trigger size --- napari_cellseg3d/code_models/crf.py | 30 ++++++++++------ napari_cellseg3d/code_models/model_workers.py | 34 ++++++++++++++----- .../code_models/models/model_VNet.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 6 ++++ .../code_plugins/plugin_model_inference.py | 3 ++ .../code_plugins/plugin_model_training.py | 6 ++-- napari_cellseg3d/utils.py | 6 ++-- 7 files changed, 61 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index a0146a5e..1b8dce28 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -33,6 +33,7 @@ from napari.qt.threading import GeneratorWorker from napari_cellseg3d.config import CRFConfig +from napari_cellseg3d.utils import LOGGER as logger __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ @@ -52,12 +53,16 @@ ] -def correct_shape_for_crf(image): - if len(image.shape) == 4: +def correct_shape_for_crf(image, desired_dims=4): + if len(image.shape) == desired_dims: return image - if len(image.shape) > 4: + if len(image.shape) > desired_dims: + if image.shape[0] > 1: + raise ValueError( + f"Image shape {image.shape} might have several channels" + ) image = np.squeeze(image, axis=0) - if len(image.shape) < 4: + if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) return correct_shape_for_crf(image) @@ -146,7 +151,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): ) -def crf_with_config(image, prob, config: CRFConfig = None): +def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): if config is None: config = CRFConfig() if image.shape[-3:] != prob.shape[-3:]: @@ -156,6 +161,12 @@ def crf_with_config(image, prob, config: CRFConfig = None): ) image = correct_shape_for_crf(image) + prob = correct_shape_for_crf(prob) + + if log is not None: + log("Running CRF post-processing step") + log(f"Image shape : {image.shape}") + log(f"Labels shape : {prob.shape}") return crf( image, @@ -196,15 +207,12 @@ def _run_crf_job(self): raise ImportError("pydensecrf is not installed.") for image, labels in zip(self.images, self.labels): - if len(image.shape) == 3: - image = np.expand_dims(image, axis=0) - - if len(labels.shape) == 3: - labels = np.expand_dims(labels, axis=0) - if image.shape[-3:] != labels.shape[-3:]: raise ValueError("Image and labels must have the same shape.") + image = correct_shape_for_crf(image) + labels = correct_shape_for_crf(labels) + yield crf( image, labels, diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index c7196db7..39e6bb91 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,4 +1,5 @@ import platform +import time import typing as t from dataclasses import dataclass from math import ceil @@ -598,7 +599,7 @@ def save_image( filetype = self.config.filetype else: original_filename = "_" - filetype = "" + filetype = ".tif" time = utils.get_date_time() @@ -712,12 +713,9 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) def run_crf(self, image, labels, image_id=0): - self.log(f"IMAGE SHAPE : {image.shape}") - self.log(f"LABEL SHAPE : {labels.shape}") - try: crf_results = crf_with_config( - image, labels, config=self.config.crf_config + image, labels, config=self.config.crf_config, log=self.log ) self.save_image( crf_results, i=image_id, additional_info="CRF", from_layer=True @@ -1152,6 +1150,8 @@ def train(self): weights_config = self.config.weights_info deterministic_config = self.config.deterministic_config + start_time = time.time() + try: if deterministic_config.enabled: set_determinism( @@ -1364,14 +1364,23 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + + factor = self.config.scheduler_factor + if factor >= 1.0: + self.log(f"Warning : scheduler factor is {factor} >= 1.0") + self.log("Setting it to 0.5") + factor = 0.5 + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode="min", - factor=self.config.scheduler_factor, + factor=factor, patience=self.config.scheduler_patience, verbose=VERBOSE_SCHEDULER, ) - dice_metric = DiceMetric(include_background=True, reduction="mean") + dice_metric = DiceMetric( + include_background=False, reduction="mean" + ) best_metric = -1 best_metric_epoch = -1 @@ -1467,6 +1476,15 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] + self.log( + "ETA: " + + str( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 + ) + + "minutes" + ) if ( (epoch + 1) % self.config.validation_interval == 0 @@ -1490,7 +1508,7 @@ def train(self): overlap=0.25, sw_device=self.config.device, device=self.config.device, - progress=True, + progress=False, ) except Exception as e: self.raise_error(e, "Error during validation") diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 41554e80..7aa6476e 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=1, **kwargs): + def __init__(self, in_channels=1, out_channels=2, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index cbdacf3a..7ac605e9 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -178,6 +178,11 @@ def _build(self): def make_config(self): return self.crf_params_widget.make_config() + def print_config(self): + logger.info("CRF config:") + for item in self.make_config().__dict__.items(): + logger.info(f"{item[0]}: {item[1]}") + def _check_ready(self): if len(self.label_layer_loader.layer_list) < 1: logger.warning("No label layer loaded") @@ -272,6 +277,7 @@ def _on_start(self): def _on_finish(self): self.worker = None + self.start_button.setText("Start") def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 472cccd8..157b8af7 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -802,6 +802,9 @@ def on_yield(self, result: InferenceResult): # ) if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) viewer.add_image( result.crf_results, name=f"CRF_results_image_{image_id}", diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 88991f43..86d1d317 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -846,7 +846,7 @@ def start(self): loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), @@ -982,7 +982,7 @@ def on_yield(self, report: TrainingReport): layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(layer) else: @@ -993,7 +993,7 @@ def on_yield(self, report: TrainingReport): new_layer = self._viewer.add_image( report.images[i], name=layer_name + str(i), - colormap="twilight", + colormap="viridis", ) self.result_layers.append(new_layer) self.result_layers[i].data = report.images[i] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index e7eaf95a..1aa316d2 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -12,8 +12,8 @@ LOGGER = logging.getLogger(__name__) ############### # Global logging level setting -# LOGGER.setLevel(logging.DEBUG) -LOGGER.setLevel(logging.INFO) +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) ############### """ utils.py @@ -311,7 +311,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 - if pad >= 256: + if pad >= 1024: LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." From 8b2fc1db138de9e373add506c79a6e0eb7c93c0a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 3 May 2023 09:57:34 +0200 Subject: [PATCH 394/577] Fixes and channel labeling prototype --- napari_cellseg3d/code_models/model_workers.py | 33 +++-- .../extract_extra_channels_labels.py | 124 ++++++++++++++++++ 2 files changed, 143 insertions(+), 14 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 39e6bb91..9f38a534 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -548,12 +548,14 @@ def create_inference_result( "A layer's ID should always be 0 (default value)" ) extra_dims = len(semantic_labels.shape) - 3 - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + if semantic_labels is not None: + semantic_labels = np.swapaxes( + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + if crf_results is not None: + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims + ) return InferenceResult( image_id=i + 1, @@ -1456,6 +1458,12 @@ def train(self): optimizer.zero_grad() outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") + if outputs.shape[1] > 1: + outputs = outputs[ + :, 1:, :, : + ] # FIXME fix channel number + if len(outputs.shape) < 4: + outputs = outputs.unsqueeze(0) loss = self.config.loss_function(outputs, labels) loss.backward() optimizer.step() @@ -1476,15 +1484,12 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] - self.log( - "ETA: " - + str( - (time.time() - start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 - ) - + "minutes" + eta = ( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 ) + self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py new file mode 100644 index 00000000..2bd0a536 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -0,0 +1,124 @@ +import numpy as np +from skimage.filters import threshold_otsu +from skimage.segmentation import expand_labels +from tqdm import tqdm + + +def extract_labels_from_channels( + nucleus_labels: np.array, + extra_channels: list, + radius: int = 4, + threshold_factor=2, + viewer=None, +): + """ + Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. + Args: + nucleus_labels (np.array): labels for the nuclei + extra_channels (list): channels arrays to extract labels from + radius: radius in which the approximation is made + + Returns: + A list of extracted labels for each extra channel + """ + labeled_channels = {} + + contrasted_channels = [] + for channel in extra_channels: + channel = (channel - np.min(channel)) / ( + np.max(channel) - np.min(channel) + ) + threshold_brightness = threshold_otsu(channel) * threshold_factor + channel_contrasted = np.where( + channel > threshold_brightness, channel, 0 + ) + contrasted_channels.append(channel_contrasted) + if viewer is not None: + viewer.add_image( + channel_contrasted, + name="channel_contrasted", + colormap="viridis", + ) + for label_id in tqdm(np.unique(nucleus_labels)): + if label_id == 0: + continue + label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + expanded = expand_labels(label_nucleus, distance=radius) + for i, channel in enumerate(contrasted_channels): + label_contrasted = np.where(expanded != 0, channel, 0) + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels[ + f"label_{label_id}_channel_{i+1}" + ] = np.count_nonzero(labeled_channel) + if np.count_nonzero(labeled_channel) > 0 and viewer is not None: + print(np.count_nonzero(labeled_channel)) + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) + + return labeled_channels + + +if __name__ == "__main__": + from pathlib import Path + + import napari + import pandas as pd + from tifffile import imread + + image_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + ) + # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + nuclei_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" + ) + extra_channels_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" + ) + extra_channels = [ + imread(str(path)) + for path in extra_channels_path.glob( + "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" + ) + ] + labels = imread(str(image_path)) + viewer = napari.Viewer() + + shift = 0 + viewer.add_image( + imread(str(nuclei_path))[ + shift : 32 + shift, shift : 32 + shift, shift : 32 + shift + ], + name="nuclei", + ) + viewer.add_labels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + [ + viewer.add_image( + channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + for channel in extra_channels + ] + + labeled_channels = extract_labels_from_channels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], + [ + c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + for c in extra_channels + ], + radius=4, + viewer=viewer, + ) + table = pd.DataFrame( + labeled_channels.items(), columns=["name", "pixels count"] + ) + print(table) + # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] + # expanded = expand_labels(labels, 4) + # viewer.add_labels(expanded) + napari.run() From be5bce6523a3c7750cbe6439507d830a693b2dc0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 09:18:42 +0200 Subject: [PATCH 395/577] Fixes - Fixed multi-channel instance and csv stats - Fixed rotation of inference outputs - Raised max crop size --- napari_cellseg3d/code_models/model_workers.py | 74 ++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- .../code_plugins/plugin_model_inference.py | 79 +++++++++---------- .../extract_extra_channels_labels.py | 64 +++++++++------ napari_cellseg3d/interface.py | 54 ++++++++----- napari_cellseg3d/utils.py | 6 ++ 6 files changed, 160 insertions(+), 119 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 9f38a534..9e0a5085 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -547,15 +547,15 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - extra_dims = len(semantic_labels.shape) - 3 + if semantic_labels is not None: - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) + semantic_labels = utils.correct_rotation(semantic_labels) if crf_results is not None: - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + crf_results = utils.correct_rotation(crf_results) + if instance_labels is not None: + instance_labels = utils.correct_rotation( + instance_labels + ) # TODO(cyril) check if correct return InferenceResult( image_id=i + 1, @@ -581,8 +581,6 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) - if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -608,10 +606,11 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"{additional_info}_Prediction_{i+1}" + + f"{additional_info}" + + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}_" + + f"_{time}" + filetype ) try: @@ -638,18 +637,20 @@ def aniso_transform(self, image): return image def instance_seg( - self, to_instance, image_id=0, original_filename="layer", channel=None + self, semantic_labels, image_id=0, original_filename="layer" ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) - if channel is not None: - channel_id = f"_{channel}" + if len(semantic_labels.shape) == 4: + instance_labels = np.array( + [method.run_method(ch) for ch in semantic_labels] + ) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") else: - channel_id = "" + instance_labels = method.run_method(image=semantic_labels) if self.config.filetype == "": filetype = "" @@ -661,7 +662,6 @@ def instance_seg( + "/" + f"Instance_seg_labels_{image_id}_" + original_filename - + channel_id + "_" + self.config.model_info.name + f"_{utils.get_date_time()}" @@ -720,7 +720,10 @@ def run_crf(self, image, labels, image_id=0): image, labels, config=self.config.crf_config, log=self.log ) self.save_image( - crf_results, i=image_id, additional_info="CRF", from_layer=True + crf_results, + i=image_id, + additional_info="CRF_", + from_layer=True, ) return crf_results except ValueError as e: @@ -728,14 +731,17 @@ def run_crf(self, image, labels, image_id=0): return None def stats_csv(self, instance_labels): - if self.config.compute_stats: - stats = volume_stats(instance_labels) - return stats - - # except ValueError as e: - # self.log(f"Error occurred during stats computing : {e}") - # return None - else: + try: + if self.config.compute_stats: + if len(instance_labels.shape) == 4: + stats = [volume_stats(c) for c in instance_labels] + else: + stats = [volume_stats(instance_labels)] + return stats + else: + return None + except ValueError as e: + self.log(f"Error occurred during stats computing : {e}") return None def inference_on_layer(self, image, model, post_process_transforms): @@ -753,15 +759,9 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels_results = [] - stats_results = [] - - for channel in out: - instance_labels, stats = self.get_instance_result( - channel, from_layer=True - ) - instance_labels_results.append(instance_labels) - stats_results.append(stats) + instance_labels, stats = self.get_instance_result( + semantic_labels=out, from_layer=True + ) if self.config.use_crf: crf_results = self.run_crf(image, out) @@ -770,10 +770,10 @@ def inference_on_layer(self, image, model, post_process_transforms): return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels_results, + instance_labels=instance_labels, crf_results=crf_results, from_layer=True, - stats=stats_results, + stats=stats, ) # @thread_worker(connect={"errored": self.raise_error}) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 6e7f91f3..323f8068 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -80,7 +80,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.results_filewidget.check_ready() self.crop_size_widgets = ui.IntIncrementCounter.make_n( - 3, 1, 1000, DEFAULT_CROP_SIZE + 3, 1, 10000, DEFAULT_CROP_SIZE ) self.crop_size_labels = [ ui.make_label("Size in " + axis + " of cropped volume :", self) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 157b8af7..9f093629 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -140,7 +140,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.thresholding_slider = ui.Slider( - lower=1, default=config.PostProcessConfig().thresholding.threshold_value * 100, divide_factor=100.0, @@ -437,10 +436,10 @@ def _build(self): self.anisotropy_wdgt, # anisotropy self.thresholding_checkbox, self.thresholding_slider.container, # thresholding - self.use_instance_choice, - self.instance_widgets, self.use_crf, self.crf_widgets, + self.use_instance_choice, + self.instance_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -754,61 +753,61 @@ def on_yield(self, result: InferenceResult): name=f"pred_{image_id}_{model_name}", opacity=0.8, ) + if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) if ( len(result.instance_labels) > 0 and self.worker_config.post_process_config.instance.enabled ): - for i, labels in enumerate(result.instance_labels): - # labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(result.instance_labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(result.instance_labels, name=name) from napari_cellseg3d.utils import LOGGER as log - log.debug(f"len stats : {len(result.stats)}") + if result.stats is not None and isinstance( + result.stats, list + ): + log.debug(f"len stats : {len(result.stats)}") - for i, stats in enumerate(result.stats): - # stats = result.stats + for i, stats in enumerate(result.stats): + # stats = result.stats - if ( - self.worker_config.compute_stats - and stats is not None - ): - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances in channel {i} : {stats.number_objects[0]}" - ) + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) - - if result.crf_results is not None: - logger.debug( - f"CRF results shape : {result.crf_results.shape}" - ) - viewer.add_image( - result.crf_results, - name=f"CRF_results_image_{image_id}", - colormap="viridis", - ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py index 2bd0a536..70ee10b6 100644 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -4,8 +4,8 @@ from tqdm import tqdm -def extract_labels_from_channels( - nucleus_labels: np.array, +def extract_labels_from_channels( # TODO add separate channels results + nuclei_labels: np.array, extra_channels: list, radius: int = 4, threshold_factor=2, @@ -14,15 +14,14 @@ def extract_labels_from_channels( """ Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. Args: - nucleus_labels (np.array): labels for the nuclei + nuclei_labels (np.array): labels for the nuclei extra_channels (list): channels arrays to extract labels from radius: radius in which the approximation is made Returns: A list of extracted labels for each extra channel """ - labeled_channels = {} - + labeled_channels = [] contrasted_channels = [] for channel in extra_channels: channel = (channel - np.min(channel)) / ( @@ -39,31 +38,54 @@ def extract_labels_from_channels( name="channel_contrasted", colormap="viridis", ) - for label_id in tqdm(np.unique(nucleus_labels)): + for label_id in tqdm(np.unique(nuclei_labels)): if label_id == 0: continue - label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) expanded = expand_labels(label_nucleus, distance=radius) + restricted = np.where(expanded != 0, nuclei_labels, 0) + overlap = np.where(restricted != label_id, restricted, 0) + for i, channel in enumerate(contrasted_channels): label_contrasted = np.where(expanded != 0, channel, 0) - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels[ - f"label_{label_id}_channel_{i+1}" - ] = np.count_nonzero(labeled_channel) - if np.count_nonzero(labeled_channel) > 0 and viewer is not None: - print(np.count_nonzero(labeled_channel)) - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) + if overlap.any() != 0: + max_labeled = 0 + for overlap_id in np.unique(overlap): + if overlap_id == 0: + continue + assigned_pixels = np.count_nonzero( + np.where(overlap == overlap_id, channel, 0) + ) + if assigned_pixels > max_labeled: + max_labeled = assigned_pixels + max_label_id = overlap_id + if label_id != max_label_id: + labeled_channels.append( + np.zeros_like(label_contrasted) + ) + else: + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels.append(labeled_channel) + if ( + np.count_nonzero(labeled_channel) > 0 + and viewer is not None + ): + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) - return labeled_channels + cat_labels = np.zeros_like(nuclei_labels) + for labels in np.unique(labeled_channels): + if labels == 0: + continue + cat_labels += np.where(labels != 0, labels, 0) + return cat_labels if __name__ == "__main__": from pathlib import Path import napari - import pandas as pd from tifffile import imread image_path = ( @@ -114,10 +136,8 @@ def extract_labels_from_channels( radius=4, viewer=viewer, ) - table = pd.DataFrame( - labeled_channels.items(), columns=["name", "pixels count"] - ) - print(table) + + viewer.add_labels(labeled_channels) # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] # expanded = expand_labels(labels, 4) # viewer.add_labels(expanded) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 55e5abb3..9d06863e 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -469,6 +469,11 @@ def __init__( ): super().__init__(orientation, parent) + if upper <= lower: + raise ValueError( + "The minimum value cannot be below the maximum one" + ) + self.setMaximum(upper) self.setMinimum(lower) self.setSingleStep(step) @@ -544,23 +549,29 @@ def _warn_outside_bounds(self, default): def _update_slider(self): """Update slider when value is changed""" - if self._value_label.text() == "": - return + try: + if self._value_label.text() == "": + return - value = float(self._value_label.text()) * self._divide_factor + value = float(self._value_label.text()) * self._divide_factor - if value < self.minimum(): - self.slider_value = self.minimum() - return - if value > self.maximum(): - self.slider_value = self.maximum() - return + if value < self.minimum(): + self.slider_value = self.minimum() + return + if value > self.maximum(): + self.slider_value = self.maximum() + return - self.slider_value = value + self.slider_value = value + except Exception as e: + logger.error(e) def _update_value_label(self): """Update label, to connect to when slider is dragged""" - self._value_label.setText(str(self.value_text)) + try: + self._value_label.setText(str(self.value_text)) + except Exception as e: + logger.error(e) @property def tooltips(self): @@ -596,16 +607,21 @@ def value_text(self): def slider_value(self, value: int): """Set a value (int) divided by self._divide_factor""" if value < self.minimum() or value > self.maximum(): - raise ValueError( - f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + logger.error( + ValueError( + f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + ) ) - self.setValue(int(value)) - - divided = value / self._divide_factor - if self._divide_factor == 1.0: - divided = int(divided) - self._value_label.setText(str(divided)) + try: + self.setValue(int(value)) + + divided = value / self._divide_factor + if self._divide_factor == 1.0: + divided = int(divided) + self._value_label.setText(str(divided)) + except Exception as e: + logger.error(e) class AnisotropyWidgets(QWidget): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 1aa316d2..75c9734e 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -202,6 +202,12 @@ def dice_coeff(y_true, y_pred): return score +def correct_rotation(image): + """Rotates the exes 0 and 2 in [DHW] section of image array""" + extra_dims = len(image) - 3 + return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, From db435d0d1d77b87a59c82b5a90684d68490e8ad6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 14:42:02 +0200 Subject: [PATCH 396/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 9f093629..df64a625 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -762,9 +762,8 @@ def on_yield(self, result: InferenceResult): name=f"CRF_results_image_{image_id}", colormap="viridis", ) - if ( - len(result.instance_labels) > 0 + result.instance_labels is not None and self.worker_config.post_process_config.instance.enabled ): method_name = ( From bc26463b1dda25b26cf4eb5445516aa891691800 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 6 May 2023 09:56:17 +0200 Subject: [PATCH 397/577] Update plugin_crop.py --- napari_cellseg3d/code_plugins/plugin_crop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 323f8068..46c2cfb2 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -49,7 +49,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.label_layer_loader.layer_list.label.setText("Image 2") self.crop_second_image_choice = ui.CheckBox( - "Crop another\nimage simultaneously", + "Crop another\nimage/label simultaneously", ) self.crop_second_image_choice.toggled.connect( self._toggle_second_image_io_visibility From 77dbc9beec1b0cdeabe22776fe77db23cff00c90 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 10:16:58 +0200 Subject: [PATCH 398/577] Fixed patch_func sample number mismatch --- napari_cellseg3d/code_models/model_workers.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 9e0a5085..88d374a0 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1230,30 +1230,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if do_sampling: - sample_loader = Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=self.config.num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), - ), - EnsureTyped(keys=["image", "label"]), - ] - ) if self.config.do_augmentation: train_transforms = ( @@ -1285,6 +1261,31 @@ def train(self): ] ) # self.log("Loading dataset...\n") + def get_loader_func(num_samples): + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) + ), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + if do_sampling: # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior @@ -1297,11 +1298,17 @@ def train(self): self.config.num_samples * (1 - self.config.validation_percent) ) + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) else: num_train_samples = ( num_val_samples ) = self.config.num_samples + sample_loader_train = get_loader_func(num_train_samples) + sample_loader_eval = get_loader_func(num_val_samples) + + logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" @@ -1311,14 +1318,14 @@ def train(self): train_ds = PatchDataset( data=self.train_files, transform=train_transforms, - patch_func=sample_loader, + patch_func=sample_loader_train, samples_per_image=num_train_samples, ) logger.debug("val_ds") val_ds = PatchDataset( data=self.val_files, transform=val_transforms, - patch_func=sample_loader, + patch_func=sample_loader_eval, samples_per_image=num_val_samples, ) From ae5d7c5fca2952f9f3db7649ad02bb9f78e8eaf2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:08:52 +0200 Subject: [PATCH 399/577] Testing relabel tools --- napari_cellseg3d/dev_scripts/correct_labels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2ab60332..9862c3fa 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -367,8 +367,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") +# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") # -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# image_path = str(im_path / "volumes/images.tif") +# gt_labels_path = str(im_path / "labels/testing_im.tif") +# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) From 9b71743b408567eaddfcceb0c33b77ffccf2349c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:38:45 +0200 Subject: [PATCH 400/577] Fixes in inference --- napari_cellseg3d/code_models/model_workers.py | 2 ++ napari_cellseg3d/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 88d374a0..754a5007 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -504,6 +504,8 @@ def model_output_wrapper(inputs): sw_device=self.config.device, device=dataset_device, overlap=window_overlap, + mode="gaussian", + sigma_scale=0.01, progress=True, ) except Exception as e: diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 75c9734e..86754ad0 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -204,7 +204,7 @@ def dice_coeff(y_true, y_pred): def correct_rotation(image): """Rotates the exes 0 and 2 in [DHW] section of image array""" - extra_dims = len(image) - 3 + extra_dims = len(image.shape) - 3 return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) From 61e7a6e3746ca6fcf09f60d638d52345a66c0977 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 14:48:14 +0200 Subject: [PATCH 401/577] add model template + fix test + wnet loading opti - test fixes - changed crf input reqs - adapted instance seg for several channels --- napari_cellseg3d/_tests/test_models.py | 10 ++- .../_tests/test_plugin_inference.py | 11 ++-- napari_cellseg3d/_tests/test_training.py | 11 ++-- napari_cellseg3d/code_models/crf.py | 11 ++-- .../code_models/model_instance_seg.py | 29 ++++++++- napari_cellseg3d/code_models/model_workers.py | 62 +++++++++---------- .../code_models/models/TEMPLATE_model.py | 20 ++++++ .../code_models/models/model_SwinUNetR.py | 13 +++- .../code_models/models/model_WNet.py | 19 ++++++ .../code_plugins/plugin_convert.py | 2 +- 10 files changed, 129 insertions(+), 59 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/TEMPLATE_model.py diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 1fc15872..35af8c76 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -15,6 +15,8 @@ def test_correct_shape_for_crf(): def test_model_list(): for model_name in MODEL_LIST.keys(): + # if model_name=="test": + # continue dims = 128 test = MODEL_LIST[model_name]( input_img_size=[dims, dims, dims], @@ -39,18 +41,20 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) - # assert res > 0 + assert 0 <= res <= 1 def test_crf(qtbot): dims = 8 mock_image = np.random.rand(1, dims, dims, dims) mock_label = np.random.rand(2, dims, dims, dims) - - crf = CRFWorker(mock_image, mock_label) + assert len(mock_label.shape) == 4 + crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) + assert len(result.shape) == 4 + assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] crf.yielded.connect(on_yield) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 66c50fba..3dafeabc 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,9 +3,10 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_inference(make_napari_viewer, qtbot): @@ -28,9 +29,9 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.setCurrentIndex(-1) # widget.start() # takes too long on Github Actions # assert widget.worker is not None diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 21731ba1..921a6d26 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,9 +2,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_training(make_napari_viewer, qtbot): @@ -32,9 +33,9 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) # widget.start() # assert widget.worker is not None diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 1b8dce28..21caf35f 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -57,10 +57,10 @@ def correct_shape_for_crf(image, desired_dims=4): if len(image.shape) == desired_dims: return image if len(image.shape) > desired_dims: - if image.shape[0] > 1: - raise ValueError( - f"Image shape {image.shape} might have several channels" - ) + # if image.shape[0] > 1: + # raise ValueError( + # f"Image shape {image.shape} might have several channels" + # ) image = np.squeeze(image, axis=0) if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) @@ -200,7 +200,6 @@ def __init__( self.config = config self.log = log - # TODO(cyril) : add progress bar into log ? or do it in inference def _run_crf_job(self): """Runs the CRF post-processing step for the W-Net.""" if not CRF_INSTALLED: @@ -211,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + # labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py index d1a03eec..0c3c6c6b 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from functools import partial from typing import List @@ -82,8 +83,32 @@ def __init__( ) self.counters.append(getattr(self, widget)) + @abc.abstractmethod def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") + raise NotImplementedError() + + def _make_list_from_channels( + self, image + ): # TODO(cyril) : adapt to batch dimension + if len(image.shape) > 4: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" + ) + if len(image.shape) == 4: + image = np.squeeze(image) + if len(image.shape) == 4: + return [im for im in image] + elif len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) + else: + return [image] + + def run_method_on_channels(self, image): + image_list = self._make_list_from_channels(image) # FIXME rename + result = np.array([self.run_method(im) for im in image_list]) + return result.squeeze() class InstanceMethod: @@ -611,7 +636,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method(volume) + return method.run_method_on_channels(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 754a5007..93f9908b 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -645,17 +645,11 @@ def instance_seg( self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - - if len(semantic_labels.shape) == 4: - instance_labels = np.array( - [method.run_method(ch) for ch in semantic_labels] - ) - self.log(f"DEBUG instance results shape : {instance_labels.shape}") - else: - instance_labels = method.run_method(image=semantic_labels) + instance_labels = method.run_method_on_channels(semantic_labels) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") if self.config.filetype == "": - filetype = "" + filetype = ".tif" else: filetype = "_" + self.config.filetype @@ -855,7 +849,8 @@ def inference(self): weights = str( PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - model.load_state_dict( + + model.load_state_dict( # note that this is redefined in WNet_ torch.load( weights, map_location=self.config.device, @@ -1232,7 +1227,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? @@ -1262,31 +1256,32 @@ def train(self): EnsureTyped(keys=["image", "label"]), ] ) + # self.log("Loading dataset...\n") def get_loader_func(num_samples): - return Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) ), - EnsureTyped(keys=["image", "label"]), - ] - ) + ), + EnsureTyped(keys=["image", "label"]), + ] + ) if do_sampling: # if there is only one volume, split samples @@ -1310,7 +1305,6 @@ def get_loader_func(num_samples): sample_loader_train = get_loader_func(num_train_samples) sample_loader_eval = get_loader_func(num_val_samples) - logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py new file mode 100644 index 00000000..f68e5f4f --- /dev/null +++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ModelTemplate_(ABC): + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 05819e22..484890d1 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -9,12 +9,19 @@ class SwinUNETR_(SwinUNETR): use_default_training = True weights_file = "Swin64_best_metric.pth" - def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + def __init__( + self, + in_channels=1, + out_channels=1, + input_img_size=128, + use_checkpoint=True, + **kwargs, + ): try: super().__init__( input_img_size, - in_channels=1, - out_channels=1, + in_channels=in_channels, + out_channels=out_channels, feature_size=48, use_checkpoint=use_checkpoint, **kwargs, diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 4a9ff70d..86a1f7e6 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,5 +1,12 @@ +from typing import TypeVar + +from torch.nn import Module + +# local from napari_cellseg3d.code_models.models.wnet.model import WNet +T = TypeVar("T", bound="Module") + class WNet_(WNet): use_default_training = False @@ -20,6 +27,9 @@ def __init__( num_classes=num_classes, ) + def train(self: T, mode: bool = True) -> T: + raise NotImplementedError("Training not implemented for WNet") + def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. @@ -27,3 +37,12 @@ def forward(self, x): enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc + + def load_state_dict(self, state_dict, strict=False): + """Load the model state dict for inference, without the decoder weights.""" + encoder_checkpoint = state_dict.copy() + for k in state_dict.keys(): + if k.startswith("decoder"): + encoder_checkpoint.pop(k) + # print(encoder_checkpoint.keys()) + super().load_state_dict(encoder_checkpoint, strict=strict) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f7b476d0..8353632e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -363,7 +363,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - self.instance_widgets.run_method(imread(file)) + self.instance_widgets.run_method_on_channels(imread(file)) for file in self.images_filepaths ] utils.save_folder( From 32e703f88a2d30d6b304dc7961989db3e8c0b5b9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 15:16:25 +0200 Subject: [PATCH 402/577] Update model_WNet.py --- napari_cellseg3d/code_models/models/model_WNet.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 86a1f7e6..f07ac517 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,12 +1,6 @@ -from typing import TypeVar - -from torch.nn import Module - # local from napari_cellseg3d.code_models.models.wnet.model import WNet -T = TypeVar("T", bound="Module") - class WNet_(WNet): use_default_training = False @@ -27,8 +21,8 @@ def __init__( num_classes=num_classes, ) - def train(self: T, mode: bool = True) -> T: - raise NotImplementedError("Training not implemented for WNet") + # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): """Forward ENCODER pass of the W-Net model. From 519c0501891812894e34f738d10d69bd8e0f5d79 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 13 May 2023 10:29:39 +0200 Subject: [PATCH 403/577] Update model_VNet.py --- napari_cellseg3d/code_models/models/model_VNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 7aa6476e..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=2, **kwargs): + def __init__(self, in_channels=1, out_channels=1, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs From 6dff8ae428b951057a56c77dcdd56d8d743c1109 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:51:02 +0200 Subject: [PATCH 404/577] Fixed folder creation when saving to folder --- napari_cellseg3d/code_models/crf.py | 2 +- napari_cellseg3d/code_plugins/plugin_convert.py | 10 +++++----- napari_cellseg3d/code_plugins/plugin_crf.py | 2 +- napari_cellseg3d/utils.py | 3 +++ 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 21caf35f..aa9cce75 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -210,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - # labels = correct_shape_for_crf(labels) + labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 8353632e..77aa9af6 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -46,7 +46,7 @@ def __init__(self, viewer: "napari.Viewer.viewer", parent=None): self.aniso_widgets = ui.AnisotropyWidgets(self, always_visible=True) self.start_btn = ui.Button("Start", self._start) - self.results_path = Path.home() / Path("cellseg3d/anisotropy") + self.results_path = str(Path.home() / Path("cellseg3d/anisotropy")) self.results_filewidget.text_field.setText(str(self.results_path)) self.results_filewidget.check_ready() @@ -76,7 +76,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + utils.mkdir_from_str(self.results_path) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -175,7 +175,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -342,7 +342,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -436,7 +436,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.binarize_counter.value() if self.layer_choice: diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 7ac605e9..d8407a0f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -238,7 +238,7 @@ def _start(self): self.result_layer = self.label_layer_loader.layer() self.result_name = self.label_layer_loader.layer_name() - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) image_list = [self.image_layer_loader.layer_data()] labels_list = [self.label_layer_loader.layer_data()] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 86754ad0..6e2f7341 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -131,6 +131,9 @@ def normalize_x(image): image = image / 127.5 - 1 return image +def mkdir_from_str(path: str, exist_ok=True, parents=True): + Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) + def normalize_y(image): """Normalizes the values of an image array to be between [0;1] rather than [0;255] From 43a78496b0319a536655f6612d5ee886619f5bd6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:54:07 +0200 Subject: [PATCH 405/577] Fix check_ready for results filewidget --- napari_cellseg3d/interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 9d06863e..6c5eb5c3 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -854,6 +854,9 @@ def __init__( self.build() self.check_ready() + if self._required: + self._text_field.textChanged.connect(self.check_ready) + def build(self): """Builds the layout of the widget""" add_widgets( @@ -912,7 +915,7 @@ def required(self, is_required): try: self.text_field.textChanged.disconnect(self.check_ready) except TypeError: - return + pass self.check_ready() self._required = is_required From e3131222bda4c403d09d3610cdd0768a7afb59a2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 11:28:33 +0200 Subject: [PATCH 406/577] Added remapping in WNet + ruff config --- .pre-commit-config.yaml | 3 ++ napari_cellseg3d/code_models/model_workers.py | 51 ++++++++----------- napari_cellseg3d/utils.py | 48 ++++++++++++----- pyproject.toml | 36 ++++++++++++- 4 files changed, 93 insertions(+), 45 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61ecaae5..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,9 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + - id: check-toml # - repo: https://github.com/pycqa/isort # rev: 5.12.0 # hooks: diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index 93f9908b..4ce4d180 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -119,7 +119,7 @@ def show_progress(_, block_size, __): # count, block_size, total_size logger.info(message) return - with open(json_path) as f: + with Path.open(json_path) as f: neturls = json.load(f) if model_name in neturls: url = neturls[model_name] @@ -259,8 +259,7 @@ def create_inference_dict(images_filepaths): Returns: dict: list of image paths from loaded folder""" - data_dicts = [{"image": image_name} for image_name in images_filepaths] - return data_dicts + return [{"image": image_name} for image_name in images_filepaths] def set_download_log(self, widget): self.downloader.log_widget = widget @@ -472,10 +471,9 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) if self.config.sliding_window_config.is_enabled(): window_size = self.config.sliding_window_config.window_size @@ -492,6 +490,7 @@ def model_output( # outputs = model(inputs) def model_output_wrapper(inputs): + inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) @@ -509,7 +508,7 @@ def model_output_wrapper(inputs): progress=True, ) except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.debug("failed to run sliding window inference") self.raise_error(e, "Error during sliding window inference") logger.debug(f"Inference output shape: {outputs.shape}") @@ -520,11 +519,9 @@ def model_output_wrapper(inputs): if post_process: out = np.array(out).astype(np.float32) out = np.squeeze(out) - return out - else: - return out + return out except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Error during sliding window inference") # sys.stdout = old_stdout # sys.stderr = old_stderr @@ -635,8 +632,7 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - else: - return image + return image def instance_seg( self, semantic_labels, image_id=0, original_filename="layer" @@ -648,10 +644,11 @@ def instance_seg( instance_labels = method.run_method_on_channels(semantic_labels) self.log(f"DEBUG instance results shape : {instance_labels.shape}") - if self.config.filetype == "": - filetype = ".tif" - else: - filetype = "_" + self.config.filetype + filetype = ( + ".tif" + if self.config.filetype == "" + else "_" + self.config.filetype + ) instance_filepath = ( self.config.results_path @@ -733,9 +730,9 @@ def stats_csv(self, instance_labels): stats = [volume_stats(c) for c in instance_labels] else: stats = [volume_stats(instance_labels)] - return stats else: - return None + stats = None + return stats except ValueError as e: self.log(f"Error occurred during stats computing : {e}") return None @@ -759,10 +756,7 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - if self.config.use_crf: - crf_results = self.run_crf(image, out) - else: - crf_results = None + crf_results = self.run_crf(image, out) if self.config.use_crf else None return self.create_inference_result( semantic_labels=out, @@ -944,7 +938,7 @@ def inference(self): model.to("cpu") # self.quit() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Inference failed") self.quit() finally: @@ -1175,10 +1169,7 @@ def train(self): do_sampling = self.config.sampling - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check model = model_class( # FIXME check if correct input_img_size=utils.get_padding_dim(size), use_checkpoint=True @@ -1411,7 +1402,7 @@ def get_loader_func(num_samples): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") - logger.error(e, exc_info=True) + logger.exception(e) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 6e2f7341..7ca29e00 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Union import napari import numpy as np @@ -9,6 +10,9 @@ from skimage.filters import gaussian from tifffile import imread, imwrite +if TYPE_CHECKING: + import torch + LOGGER = logging.getLogger(__name__) ############### # Global logging level setting @@ -128,8 +132,8 @@ def normalize_x(image): Returns: array: normalized value for the image """ - image = image / 127.5 - 1 - return image + return image / 127.5 - 1 + def mkdir_from_str(path: str, exist_ok=True, parents=True): Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) @@ -144,8 +148,7 @@ def normalize_y(image): Returns: array: normalized value for the image """ - image = image / 255 - return image + return image / 255 def sphericity_volume_area(volume, surface_area): @@ -199,10 +202,9 @@ def dice_coeff(y_true, y_pred): y_true_f = y_true.flatten() y_pred_f = y_pred.flatten() intersection = np.sum(y_true_f * y_pred_f) - score = (2.0 * intersection + smooth) / ( + return (2.0 * intersection + smooth) / ( np.sum(y_true_f) + np.sum(y_pred_f) + smooth ) - return score def correct_rotation(image): @@ -211,6 +213,27 @@ def correct_rotation(image): return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) +def normalize_max(image): + """Normalizes an image using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image.reshape(shape) + return image + + +def remap_image( + image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + image = image.reshape(shape) + return image + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, @@ -276,10 +299,11 @@ def time_difference(time_start, time_finish, as_string=True): minutes = f"{int(minutes[0])}".zfill(2) seconds = f"{int(seconds[0])}".zfill(2) - if as_string: - return f"{hours}:{minutes}:{seconds}" - else: - return [hours, minutes, seconds] + return ( + f"{hours}:{minutes}:{seconds}" + if as_string + else [hours, minutes, seconds] + ) def get_padding_dim(image_shape, anisotropy_factor=None): @@ -549,10 +573,8 @@ def load_images( "Loading as folder not implemented yet. Use napari to load as folder" ) # images_original = dask_imread(filename_pattern_original) - else: - images_original = imread(filename_pattern_original) # tifffile imread - return images_original + return imread(filename_pattern_original) # tifffile imread # def load_predicted_masks(mito_mask_dir, er_mask_dir, filetype): diff --git a/pyproject.toml b/pyproject.toml index d223072a..81d2a788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,11 +46,43 @@ where = ["."] [tool.ruff] select = [ "E", "F", "W", - "I", + "A", "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -ignore = ["E501", "E741"] +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] [tool.black] line-length = 79 From 6e0d4e311054e498c65b3cc16cb64d442137144f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 13:21:06 +0200 Subject: [PATCH 407/577] Run new hooks --- napari_cellseg3d/_tests/test_models.py | 13 +- .../_tests/test_weight_download.py | 2 +- napari_cellseg3d/code_models/crf.py | 25 ++-- ...stance_seg.py => instance_segmentation.py} | 19 ++- .../code_models/model_framework.py | 11 +- .../code_models/models/model_SwinUNetR.py | 2 +- .../code_models/models/model_TRAILMAP_MS.py | 2 +- .../code_models/models/model_WNet.py | 8 +- .../code_models/models/unet/buildingblocks.py | 3 +- .../code_models/models/wnet/soft_Ncuts.py | 4 +- .../{model_workers.py => workers.py} | 2 +- .../code_plugins/plugin_convert.py | 127 +++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 6 +- .../code_plugins/plugin_metrics.py | 12 +- .../code_plugins/plugin_model_inference.py | 11 +- .../code_plugins/plugin_model_training.py | 18 +-- .../code_plugins/plugin_review.py | 11 +- .../code_plugins/plugin_review_dock.py | 5 +- napari_cellseg3d/config.py | 8 +- .../dev_scripts/artefact_labeling.py | 16 +-- .../dev_scripts/correct_labels.py | 7 +- napari_cellseg3d/interface.py | 58 ++++---- pyproject.toml | 2 + 23 files changed, 191 insertions(+), 181 deletions(-) rename napari_cellseg3d/code_models/{model_instance_seg.py => instance_segmentation.py} (99%) rename napari_cellseg3d/code_models/{model_workers.py => workers.py} (99%) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35af8c76..35174b85 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,20 +1,23 @@ import numpy as np import torch +from numpy.random import PCG64, Generator from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +rand_gen = Generator(PCG64(12345)) + def test_correct_shape_for_crf(): - test = np.random.rand(1, 1, 8, 8, 8) + test = rand_gen.random(size=(1, 1, 8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) - test = np.random.rand(8, 8, 8) + test = rand_gen.random(size=(8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) def test_model_list(): - for model_name in MODEL_LIST.keys(): + for model_name in MODEL_LIST: # if model_name=="test": # continue dims = 128 @@ -46,8 +49,8 @@ def test_soft_ncuts_loss(): def test_crf(qtbot): dims = 8 - mock_image = np.random.rand(1, dims, dims, dims) - mock_label = np.random.rand(2, dims, dims, dims) + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index b9d4abe5..be694d99 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( PRETRAINED_WEIGHTS_DIR, WeightsDownloader, ) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index aa9cce75..8c311059 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,17 +54,15 @@ def correct_shape_for_crf(image, desired_dims=4): - if len(image.shape) == desired_dims: - return image if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( # f"Image shape {image.shape} might have several channels" # ) image = np.squeeze(image, axis=0) - if len(image.shape) < desired_dims: + elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) - return correct_shape_for_crf(image) + return image def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): @@ -185,8 +183,8 @@ class CRFWorker(GeneratorWorker): def __init__( self, - images_list, - labels_list, + images_list: list, + labels_list: list, config: CRFConfig = None, log=None, ): @@ -205,16 +203,19 @@ def _run_crf_job(self): if not CRF_INSTALLED: raise ImportError("pydensecrf is not installed.") - for image, labels in zip(self.images, self.labels): - if image.shape[-3:] != labels.shape[-3:]: + if len(self.images) != len(self.labels): + raise ValueError("Number of images and labels must be the same.") + + for i in range(len(self.images)): + if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + im = correct_shape_for_crf(self.labels[i]) + prob = correct_shape_for_crf(self.labels[i]) yield crf( - image, - labels, + im, + prob, self.config.sa, self.config.sb, self.config.sg, diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/instance_segmentation.py similarity index 99% rename from napari_cellseg3d/code_models/model_instance_seg.py rename to napari_cellseg3d/code_models/instance_segmentation.py index 0c3c6c6b..5de7ab0c 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -94,16 +94,16 @@ def _make_list_from_channels( raise ValueError( f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" ) + if len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) if len(image.shape) == 4: image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - elif len(image.shape) < 2: - raise ValueError( - f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" - ) - else: return [image] + return None def run_method_on_channels(self, image): image_list = self._make_list_from_channels(image) # FIXME rename @@ -353,12 +353,10 @@ def to_instance(image, is_file_path=False): image = [imread(image)] # image = image.compute() - result = binary_watershed( + return binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 ) # FIXME add params from utils plugin - return result - def to_semantic(image, is_file_path=False): """Converts a **ground-truth** label to semantic (binary 0/1) labels. @@ -375,8 +373,7 @@ def to_semantic(image, is_file_path=False): # image = image.compute() image[image >= 1] = 1 - result = image.astype(np.uint16) - return result + return image.astype(np.uint16) def volume_stats(volume_image): @@ -620,7 +617,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 37fc6a49..60644916 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,8 +1,11 @@ from pathlib import Path +from typing import TYPE_CHECKING -import napari import torch +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QProgressBar, QSizePolicy @@ -126,7 +129,7 @@ def save_log(self): path = self.results_path if len(log) != 0: - with open( + with Path.open( path + f"/Log_report_{utils.get_date_time()}.txt", "x", ) as f: @@ -152,8 +155,8 @@ def save_log_to_path(self, path): ) if len(log) != 0: - with open( - path, + with Path.open( + Path(path), "x", ) as f: f.write(log) diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 484890d1..2d7b5ef6 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -27,7 +27,7 @@ def __init__( **kwargs, ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( input_img_size, in_channels=1, diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 1123173a..baf8635d 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -16,7 +16,7 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs): in_channels=in_channels, out_channels=out_channels, **kwargs ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( in_channels=in_channels, out_channels=out_channels ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index f07ac517..7235bd61 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -28,14 +28,14 @@ def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. """ - enc = self.forward_encoder(x) - # dec = self.forward_decoder(enc) - return enc + return self.forward_encoder(x) + # enc = self.forward_encoder(x) + # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" encoder_checkpoint = state_dict.copy() - for k in state_dict.keys(): + for k in state_dict: if k.startswith("decoder"): encoder_checkpoint.pop(k) # print(encoder_checkpoint.keys()) diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 73913ab8..ce7d378f 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -422,8 +422,7 @@ def forward(self, encoder_features, x): def _joining(encoder_features, x, concat): if concat: return torch.cat((encoder_features, x), dim=1) - else: - return encoder_features + x + return encoder_features + x def create_encoders( diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 4e84579f..938292c2 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -206,6 +206,7 @@ def forward(self, labels, inputs): return torch.add(torch.neg(loss), K) """ + return None def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -348,5 +349,4 @@ def get_weights(self, inputs): 1, 1, self.W_X.shape[0], self.W_X.shape[1] ) # (1, 1, H*W*D, H*W*D) - W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) - return W + return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/workers.py similarity index 99% rename from napari_cellseg3d/code_models/model_workers.py rename to napari_cellseg3d/code_models/workers.py index 4ce4d180..c1ed62fd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -54,7 +54,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( ImageStats, volume_stats, ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 77aa9af6..4357e51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -7,7 +7,7 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, clear_small_objects, threshold, @@ -98,18 +98,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -193,18 +194,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -274,18 +276,19 @@ def _start(self): utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -360,18 +363,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method_on_channels(imread(file)) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method_on_channels(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -454,18 +458,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 46c2cfb2..a27b4baa 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -157,8 +157,10 @@ def _build(self): dim_group_l.addWidget(self.aniso_widgets) [ dim_group_l.addWidget(widget, alignment=ui.ABS_AL) - for list in zip(self.crop_size_labels, self.crop_size_widgets) - for widget in list + for widget_list in zip( + self.crop_size_labels, self.crop_size_widgets + ) + for widget in widget_list ] dim_group_w.setLayout(dim_group_l) layout.addWidget(dim_group_w) diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 114025f6..2a6e713c 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt -import napari import numpy as np from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, @@ -8,9 +9,12 @@ from monai.transforms import SpatialPad, ToTensor from tifffile import imread +if TYPE_CHECKING: + import napari + from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.instance_segmentation import to_semantic from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder DEFAULT_THRESHOLD = 0.5 @@ -187,11 +191,11 @@ def compute_dice(self): self.canvas = ( None # kind of terrible way to stack plots... but it works. ) - id = 0 + image_id = 0 for ground_path, pred_path in zip( self.images_filepaths, self.labels_filepaths ): - id += 1 + image_id += 1 ground = imread(ground_path) pred = imread(pred_path) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index df64a625..bb46617d 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,18 +1,21 @@ from functools import partial +from typing import TYPE_CHECKING -import napari import numpy as np import pandas as pd +if TYPE_CHECKING: + import napari + # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.workers import ( InferenceResult, InferenceWorker, ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 86d1d317..35a16799 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,9 +1,9 @@ import shutil from functools import partial from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt -import napari import numpy as np import pandas as pd import torch @@ -12,6 +12,9 @@ ) from matplotlib.figure import Figure +if TYPE_CHECKING: + import napari + # MONAI from monai.losses import ( DiceCELoss, @@ -29,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( TrainingReport, TrainingWorker, ) @@ -414,11 +417,10 @@ def check_ready(self): * False and displays a warning if not """ - if self.images_filepaths != [] and self.labels_filepaths != []: - return True - else: + if self.images_filepaths == [] and self.labels_filepaths != []: logger.warning("Image and label paths are not correctly set") return False + return True def _build(self): """Builds the layout of the widget and creates the following tabs and prompts: @@ -999,7 +1001,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1131,7 +1133,7 @@ def update_loss_plot(self, loss, metric): epoch = len(loss) if epoch < self.worker_config.validation_interval * 2: return - elif epoch == self.worker_config.validation_interval * 2: + if epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): self.canvas = FigureCanvas(Figure(figsize=(10, 1.5))) @@ -1167,7 +1169,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 235595e4..dd98bcd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -178,11 +178,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - logger.warning( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -400,7 +399,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 8753a642..f634d117 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,9 +1,12 @@ from datetime import datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING -import napari import pandas as pd +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QVBoxLayout, QWidget diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 8a7c1565..5c0b34be 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,7 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ @@ -89,9 +89,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :") + for model_name in MODEL_LIST: + logger.info(f" * {model_name}") return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b4712aec..93746eb6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ -import os +import os # TODO(cyril): remove os +from pathlib import Path import napari import numpy as np @@ -6,7 +7,7 @@ from skimage.filters import threshold_otsu from tifffile import imread, imwrite -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -289,18 +290,13 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. """ - if not is_labeled: - # find all the connected components in the artefacts image - labels = ndimage.label(artefacts)[0] - else: - labels = artefacts + labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - artefacts = np.where(np.isin(labels, labels_i), labels, 0) - return artefacts + return np.where(np.isin(labels, labels_i), labels, 0) def create_artefact_labels( @@ -388,7 +384,7 @@ def create_artefact_labels_from_folder( path_labels.sort() path_images.sort() # create the output folder - os.makedirs(path + "/artefact_neurons", exist_ok=True) + Path().mkdir(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9862c3fa..4a7363b2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -12,7 +12,7 @@ from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) @@ -228,10 +228,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 6c5eb5c3..df00ad0b 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,3 +1,4 @@ +import contextlib import threading from functools import partial from typing import List, Optional @@ -104,12 +105,12 @@ def __call__(cls, *args, **kwargs): ################## -def handle_adjust_errors(widget, type, context, msg: str): +def handle_adjust_errors(widget, warning_type, context, msg: str): """Qt message handler that attempts to react to errors when setting the window size and resizes the main window""" pass # head = msg.split(": ")[0] - # if type == QtWarningMsg and head == "QWindowsWindow::setGeometry": + # if warning_type == QtWarningMsg and head == "QWindowsWindow::setGeometry": # logger.warning( # f"Qt resize error : {msg}\nhas been handled by attempting to resize the window" # ) @@ -332,8 +333,7 @@ def toggle_visibility(checkbox, widget): def add_label(widget, label, label_before=True, horizontal=True): if label_before: return combine_blocks(widget, label, horizontal=horizontal) - else: - return combine_blocks(label, widget, horizontal=horizontal) + return combine_blocks(label, widget, horizontal=horizontal) class ContainerWidget(QWidget): @@ -735,8 +735,7 @@ def anisotropy_zoom_factor(aniso_res): """ base = min(aniso_res) - zoom_factors = [base / res for res in aniso_res] - return zoom_factors + return [base / res for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" @@ -796,8 +795,8 @@ def _remove_layer(self, event): index = self.layer_list.findText(removed_layer.name) self.layer_list.removeItem(index) - def set_layer_type(self, type): # no @property due to Qt constraint - self.layer_type = type + def set_layer_type(self, layer_type): # no @property due to Qt constraint + self.layer_type = layer_type [self.layer_list.removeItem(i) for i in range(self.layer_list.count())] self._check_for_layers() @@ -810,7 +809,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: logger.warning("Please select a valid layer !") - return + return None return self.layer().data @@ -898,9 +897,8 @@ def check_ready(self): self.update_field_color("indianred") self.text_field.setToolTip("Mandatory field !") return False - else: - self.update_field_color(f"{napari_param_darkgrey}") - return True + self.update_field_color(f"{napari_param_darkgrey}") + return True @property def required(self): @@ -912,10 +910,9 @@ def required(self, is_required): if is_required: self.text_field.textChanged.connect(self.check_ready) else: - try: + with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - except TypeError: - pass + self.check_ready() self._required = is_required @@ -1002,22 +999,22 @@ def make_scrollable( def set_spinbox( box, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, fixed: Optional[bool] = True, ): """Args: box : QSpinBox or QDoubleSpinBox - min : minimum value, defaults to 0 - max : maximum value, defaults to 10 + min_value : minimum value, defaults to 0 + max_value : maximum value, defaults to 10 default : default value, defaults to 0 step : step value, defaults to 1 fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" - box.setMinimum(min) - box.setMaximum(max) + box.setMinimum(min_value) + box.setMaximum(max_value) box.setSingleStep(step) box.setValue(default) @@ -1028,8 +1025,8 @@ def set_spinbox( def make_n_spinboxes( class_, n: int = 2, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, parent: Optional[QWidget] = None, @@ -1040,8 +1037,8 @@ def make_n_spinboxes( Args: class_ : QSpinBox or QDoubleSpinbox n (int): number of increment counters to create - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 + min_value (Optional[int]): minimum value, defaults to 0 + max_value (Optional[int]): maximum value, defaults to 10 default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None @@ -1052,7 +1049,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min, max, default, step, parent, fixed) + box = class_(min_value, max_value, default, step, parent, fixed) boxes.append(box) return boxes @@ -1225,10 +1222,9 @@ def open_file_dialog( default_path = utils.parse_default_path(possible_paths) - f_name = QFileDialog.getOpenFileName( + return QFileDialog.getOpenFileName( widget, "Choose file", default_path, filetype ) - return f_name def open_folder_dialog( @@ -1238,10 +1234,9 @@ def open_folder_dialog( default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - filenames = QFileDialog.getExistingDirectory( + return QFileDialog.getExistingDirectory( widget, "Open directory", default_path + "/.." ) - return filenames def make_label(name, parent=None): # TODO update to child class @@ -1258,12 +1253,11 @@ def make_label(name, parent=None): # TODO update to child class label = QLabel(name, parent) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label else: label = QLabel(name) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label + return label def make_group(title, l=7, t=20, r=7, b=11, parent=None): diff --git a/pyproject.toml b/pyproject.toml index 81d2a788..7210af6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ exclude = [ "dist", "node_modules", "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", ] [tool.black] From 0c2a9dedc6b9fc8f6ae1b8b223849cfebcb44b86 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:06:24 +0200 Subject: [PATCH 408/577] Small docs update --- docs/index.rst | 4 +- docs/res/code/instance_segmentation.rst | 53 +++++++++++++++++++ docs/res/code/model_instance_seg.rst | 53 ------------------- docs/res/code/plugin_convert.rst | 15 ------ docs/res/code/utils.rst | 4 -- .../code/{model_workers.rst => workers.rst} | 8 +-- docs/res/guides/custom_model_template.rst | 28 +++++++++- docs/res/guides/detailed_walkthrough.rst | 4 +- docs/res/guides/inference_module_guide.rst | 2 +- docs/res/guides/training_module_guide.rst | 2 +- napari_cellseg3d/code_models/workers.py | 28 +++++----- 11 files changed, 105 insertions(+), 96 deletions(-) create mode 100644 docs/res/code/instance_segmentation.rst delete mode 100644 docs/res/code/model_instance_seg.rst rename docs/res/code/{model_workers.rst => workers.rst} (78%) diff --git a/docs/index.rst b/docs/index.rst index 7e809fbe..46c57c08 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,8 +39,8 @@ Welcome to napari-cellseg3d's documentation! res/code/plugin_convert res/code/plugin_metrics res/code/model_framework - res/code/model_workers - res/code/model_instance_seg + res/code/workers + res/code/instance_segmentation res/code/plugin_model_inference res/code/plugin_model_training res/code/utils diff --git a/docs/res/code/instance_segmentation.rst b/docs/res/code/instance_segmentation.rst new file mode 100644 index 00000000..143560c4 --- /dev/null +++ b/docs/res/code/instance_segmentation.rst @@ -0,0 +1,53 @@ +instance_segmentation.py +=========================================== + +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::VoronoiOtsu + :members: __init__ + + +Functions +------------- + +binary_connected +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_connected + +binary_watershed +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_watershed + +volume_stats +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::volume_stats + +clear_small_objects +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::clear_small_objects + +to_instance +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_instance + +to_semantic +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_semantic diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst deleted file mode 100644 index 3b323173..00000000 --- a/docs/res/code/model_instance_seg.rst +++ /dev/null @@ -1,53 +0,0 @@ -model_instance_seg.py -=========================================== - -Classes -------------- - -InstanceMethod -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod - :members: __init__ - -ConnectedComponents -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents - :members: __init__ - -Watershed -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed - :members: __init__ - -VoronoiOtsu -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu - :members: __init__ - - -Functions -------------- - -binary_connected -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_connected - -binary_watershed -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_watershed - -volume_stats -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::volume_stats - -clear_small_objects -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::clear_small_objects - -to_instance -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_instance - -to_semantic -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_semantic diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 03944510..25006d0f 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,18 +28,3 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ - -Functions ------------------------------------ - -save_folder -***************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder - -save_layer -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer - -show_result -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index e90ee7e0..d9fdcfa2 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,7 +62,3 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images - -format_Warning -************************************** -.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/docs/res/code/model_workers.rst b/docs/res/code/workers.rst similarity index 78% rename from docs/res/code/model_workers.rst rename to docs/res/code/workers.rst index 85f8da29..1f5167ad 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/workers.rst @@ -1,4 +1,4 @@ -model_workers.py +workers.py =========================================== @@ -10,7 +10,7 @@ Class : LogSignal Attributes ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::LogSignal +.. autoclass:: napari_cellseg3d.code_models.workers::LogSignal :members: log_signal :noindex: @@ -24,7 +24,7 @@ Class : InferenceWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::InferenceWorker +.. autoclass:: napari_cellseg3d.code_models.workers::InferenceWorker :members: __init__, log, create_inference_dict, inference :noindex: @@ -39,6 +39,6 @@ Class : TrainingWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker +.. autoclass:: napari_cellseg3d.code_models.workers::TrainingWorker :members: __init__, log, train :noindex: diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 218795b1..a70df29b 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -3,9 +3,33 @@ Advanced : Declaring a custom model ============================================= -To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder: +.. warning:: + **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. + + Please `file an issue`_ if you would like to add a custom model and we will help you get it working. + +To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: + + class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) # download URL goes in pretrained_models.json + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass + .. note:: **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute -:: +.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues diff --git a/docs/res/guides/detailed_walkthrough.rst b/docs/res/guides/detailed_walkthrough.rst index 407893c2..3d06d998 100644 --- a/docs/res/guides/detailed_walkthrough.rst +++ b/docs/res/guides/detailed_walkthrough.rst @@ -1,6 +1,6 @@ .. _detailed_walkthrough: -Detailed walkthrough +Detailed walkthrough - Supervised learning =================================== The following guide will show you how to use the plugin's workflow, starting from human-labeled annotation volume, to running inference on novel volumes. @@ -109,7 +109,7 @@ of two no matter the size you choose. For optimal performance, make sure to use a power of two still, such as 64 or 120. .. important:: - Using a too large value for the size will cause memory issues. If this happens, restart napari (better handling for these situations might be added in the future). + Using a too large value for the size will cause memory issues. If this happens, restart the worker with smaller volumes. You also have the option to use data augmentation, which can improve performance and generalization. In most cases this should left enabled. diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst index 00e67078..373e9d0d 100644 --- a/docs/res/guides/inference_module_guide.rst +++ b/docs/res/guides/inference_module_guide.rst @@ -132,4 +132,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_inference` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index 05ce69be..1038dc6d 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -128,4 +128,4 @@ Source code -------------------------------- * :doc:`../code/plugin_model_training` * :doc:`../code/model_framework` -* :doc:`../code/model_workers` +* :doc:`../code/workers` diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c1ed62fd..e2e21363 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -61,16 +61,6 @@ logger = utils.LOGGER -""" -Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) -so instead, following the instructions in the guides below to have a worker with custom signals, I implemented -a custom worker function.""" - -# FutureReference(): -# https://python-forum.io/thread-31349.html -# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ -# https://napari-staging-site.github.io/guides/stable/threading.html - PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) @@ -174,12 +164,23 @@ def safe_extract( ) +""" +Writing something to log messages from outside the main thread needs specific care, +Following the instructions in the guides below to have a worker with custom signals, +a custom worker function was implemented. +""" + +# https://python-forum.io/thread-31349.html +# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ +# https://napari-staging-site.github.io/guides/stable/threading.html + + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_ + Separate from Worker instances as indicated `on this post`_ - .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect """ # TODO link ? log_signal = Signal(str) @@ -196,6 +197,9 @@ def __init__(self): super().__init__() +# TODO(cyril): move inference and training workers to separate files + + @dataclass class InferenceResult: """Class to record results of a segmentation job""" From 4a8188273218345b8668b3ef47bf94a92c4c80d2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:24:43 +0200 Subject: [PATCH 409/577] Testing fix --- napari_cellseg3d/code_models/instance_segmentation.py | 5 ++--- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 5de7ab0c..2240e3bd 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -102,11 +102,10 @@ def _make_list_from_channels( image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - return [image] - return None + return [image] def run_method_on_channels(self, image): - image_list = self._make_list_from_channels(image) # FIXME rename + image_list = self._make_list_from_channels(image) result = np.array([self.run_method(im) for im in image_list]) return result.squeeze() diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 7235bd61..cb5ef6d8 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,7 @@ def __init__( num_classes=num_classes, ) - # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): From b9b377a953ee9096e665cf88a539c8d5888b6314 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:59:05 +0200 Subject: [PATCH 410/577] Fixed multithread testing (locally) --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 14 +- .../_tests/test_plugin_inference.py | 29 ++-- napari_cellseg3d/_tests/test_training.py | 27 ++-- .../code_plugins/plugin_model_inference.py | 125 ++++++++++-------- .../code_plugins/plugin_model_training.py | 108 ++++++++------- 6 files changed, 158 insertions(+), 146 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 88a67ae2..fa6905d5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -9,6 +9,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35174b85..4852f651 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -52,7 +52,7 @@ def test_crf(qtbot): mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 - crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) + crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) @@ -60,20 +60,20 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=True ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() mock_image = mock_image[0] mock_label = mock_label[0] crf = CRFWorker(mock_image, mock_label) - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=False ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 3dafeabc..d1264218 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,10 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_inference(make_napari_viewer, qtbot): @@ -29,14 +28,16 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.setCurrentIndex(-1) - - # widget.start() # takes too long on Github Actions - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: - # blocker.connect(widget.worker.errored) - - #### assert len(viewer.layers) == 2 + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + widget.worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(widget.config) + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() # takes too long on Github Actions + assert widget.worker is not None + # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 921a6d26..4d558363 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,10 +2,9 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_training(make_napari_viewer, qtbot): @@ -33,15 +32,19 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - - # widget.start() - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. - # blocker.connect(widget.worker.errored) + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + + worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(worker_config) + + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() + assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index bb46617d..ba23e1df 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -551,64 +551,7 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.model_info = config.ModelInfo( - name=self.model_choice.currentText(), - model_input_size=self.model_input_size.value(), - ) - - self.weights_config.custom = self.custom_weights_choice.isChecked() - - save_path = self.results_filewidget.text_field.text() - if not self._check_results_path(save_path): - msg = f"ERROR: please set valid results path. Current path is {save_path}" - self.log.print_and_log(msg) - logger.warning(msg) - else: - if self.results_path is None: - self.results_path = save_path - - zoom_config = config.Zoom( - enabled=self.anisotropy_wdgt.enabled(), - zoom_values=self.anisotropy_wdgt.scaling_xyz(), - ) - thresholding_config = config.Thresholding( - enabled=self.thresholding_checkbox.isChecked(), - threshold_value=self.thresholding_slider.slider_value, - ) - - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] - ) - - self.post_process_config = config.PostProcessConfig( - zoom=zoom_config, - thresholding=thresholding_config, - instance=self.instance_config, - ) - - if self.window_infer_box.isChecked(): - size = int(self.window_size_choice.currentText()) - window_config = config.SlidingWindowConfig( - window_size=size, - window_overlap=self.window_overlap_slider.slider_value, - ) - else: - window_config = config.SlidingWindowConfig() - - self.worker_config = config.InferenceWorkerConfig( - device=self.get_device(), - model_info=self.model_info, - weights_config=self.weights_config, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), - compute_stats=self.save_stats_to_csv_box.isChecked(), - post_process_config=self.post_process_config, - sliding_window_config=window_config, - use_crf=self.use_crf.isChecked(), - crf_config=self.crf_widgets.make_config(), - ) + self._set_worker_config() ##################### ##################### ##################### @@ -650,6 +593,72 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.InferenceWorkerConfig): + return InferenceWorker(worker_config=config) + + def _set_worker_config(self) -> config.InferenceWorkerConfig: + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) + + self.weights_config.custom = self.custom_weights_choice.isChecked() + + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + logger.warning(msg) + else: + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], + ) + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, + ) + else: + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), + ) + return self.worker_config + def on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`""" self.display_status_report() diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 35a16799..e11eb3de 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -808,64 +808,10 @@ def start(self): self.data = None raise err - model_config = config.ModelInfo( - name=self.model_choice.currentText() - ) - - self.weights_config.path = self.weights_config.path - self.weights_config.custom = self.custom_weights_choice.isChecked() - self.weights_config.use_pretrained = ( - not self.use_transfer_choice.isChecked() - ) - - deterministic_config = config.DeterministicConfig( - enabled=self.use_deterministic_choice.isChecked(), - seed=self.box_seed.value(), - ) - - validation_percent = ( - self.validation_percent_choice.slider_value / 100 - ) - - results_path_folder = Path( - self.results_path - + f"/{model_config.name}_{utils.get_date_time()}" - ) - Path(results_path_folder).mkdir( - parents=True, exist_ok=False - ) # avoid overwrite where possible - - patch_size = [w.value() for w in self.patch_size_widgets] - - logger.debug("Loading config...") - self.worker_config = config.TrainingWorkerConfig( - device=self.get_device(), - model_info=model_config, - weights_info=self.weights_config, - train_data_dict=self.data, - validation_percent=validation_percent, - max_epochs=self.epoch_choice.value(), - loss_function=self.get_loss(self.loss_choice.currentText()), - learning_rate=float(self.learning_rate_choice.currentText()), - scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.slider_value, - validation_interval=self.val_interval_choice.value(), - batch_size=self.batch_choice.slider_value, - results_path_folder=str(results_path_folder), - sampling=self.patch_choice.isChecked(), - num_samples=self.sample_choice_slider.slider_value, - sample_size=patch_size, - do_augmentation=self.augment_choice.isChecked(), - deterministic_config=deterministic_config, - ) # TODO(cyril) continue to put params in config - self.config = config.TrainerConfig( save_as_zip=self.zip_choice.isChecked() ) - - self.log.print_and_log( - f"Saving results to : {results_path_folder}" - ) + self._set_worker_config() self.worker = TrainingWorker(config=self.worker_config) self.worker.set_download_log(self.log) @@ -895,6 +841,58 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config(self, config: config.TrainingWorkerConfig): + return TrainingWorker(config=config) + + def _set_worker_config(self) -> config.TrainingWorkerConfig: + model_config = config.ModelInfo(name=self.model_choice.currentText()) + + self.weights_config.path = self.weights_config.path + self.weights_config.custom = self.custom_weights_choice.isChecked() + self.weights_config.use_pretrained = ( + not self.use_transfer_choice.isChecked() + ) + + deterministic_config = config.DeterministicConfig( + enabled=self.use_deterministic_choice.isChecked(), + seed=self.box_seed.value(), + ) + + validation_percent = self.validation_percent_choice.slider_value / 100 + + results_path_folder = Path( + self.results_path + f"/{model_config.name}_{utils.get_date_time()}" + ) + Path(results_path_folder).mkdir( + parents=True, exist_ok=False + ) # avoid overwrite where possible + + patch_size = [w.value() for w in self.patch_size_widgets] + + logger.debug("Loading config...") + self.worker_config = config.TrainingWorkerConfig( + device=self.get_device(), + model_info=model_config, + weights_info=self.weights_config, + train_data_dict=self.data, + validation_percent=validation_percent, + max_epochs=self.epoch_choice.value(), + loss_function=self.get_loss(self.loss_choice.currentText()), + learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, + validation_interval=self.val_interval_choice.value(), + batch_size=self.batch_choice.slider_value, + results_path_folder=str(results_path_folder), + sampling=self.patch_choice.isChecked(), + num_samples=self.sample_choice_slider.slider_value, + sample_size=patch_size, + do_augmentation=self.augment_choice.isChecked(), + deterministic_config=deterministic_config, + ) # TODO(cyril) continue to put params in config + + return self.worker_config + def on_start(self): """Catches started signal from worker""" From ecc127fdc5e48a78b020ddb2d50bba0ac1b60787 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:06:02 +0200 Subject: [PATCH 411/577] Added proper tests for train/infer --- .../_tests/test_plugin_inference.py | 36 ++++++++++++++----- napari_cellseg3d/_tests/test_training.py | 34 ++++++++++++------ napari_cellseg3d/code_models/workers.py | 4 +-- .../code_plugins/plugin_model_inference.py | 8 +++-- .../code_plugins/plugin_model_training.py | 10 ++++-- 5 files changed, 67 insertions(+), 25 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index d1264218..04305082 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -4,7 +4,10 @@ from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_inference import ( + InferenceResult, + Inferer, +) from napari_cellseg3d.config import MODEL_LIST @@ -28,16 +31,31 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(widget.config) - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() # takes too long on Github Actions - assert widget.worker is not None + assert widget.worker_config is not None + assert widget.model_info is not None + worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None + assert worker.config.model_info is not None + worker.config.layer = viewer.layers[0].data + assert worker.config.layer is not None + worker.log_parameters() + + res = next(worker.inference()) + assert isinstance(res, InferenceResult) + assert res.result.shape == (6, 6, 6) + + # def on_error(e): + # print(e) + # assert False + # with qtbot.waitSignal( + # signal=worker.finished, timeout=10000, raising=True + # ) as blocker: + # worker.error_signal.connect(on_error) + # blocker.connect(worker.errored) + # worker.inference() # takes too long on Github Actions # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 4d558363..080df419 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,7 +3,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_model_training import ( + Trainer, + TrainingReport, +) from napari_cellseg3d.config import MODEL_LIST @@ -32,19 +35,30 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(worker_config) - - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() - assert widget.worker is not None + worker = widget._create_worker_from_config(worker_config) + worker.config.train_data_dict = [{"image": im_path, "label": im_path}] + worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.log_parameters() + res = next(worker.train()) + + assert isinstance(res, TrainingReport) + + # def on_error(e): + # print(e) + # assert False + # + # with qtbot.waitSignal( + # signal=widget.worker.finished, timeout=10000, raising=True + # ) as blocker: + # blocker.connect(widget.worker.errored) + # widget.worker.error_signal.connect(on_error) + # widget.worker.train() + # assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index e2e21363..6dd32c80 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -965,7 +965,7 @@ class TrainingWorker(GeneratorWorker): def __init__( self, - config: config.TrainingWorkerConfig, + worker_config: config.TrainingWorkerConfig, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -1012,7 +1012,7 @@ def __init__( self._weight_error = False ############################################# - self.config = config + self.config = worker_config self.train_files = [] self.val_files = [] diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index ba23e1df..1d8c0620 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -593,8 +593,12 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.InferenceWorkerConfig): - return InferenceWorker(worker_config=config) + def _create_worker_from_config( + self, worker_config: config.InferenceWorkerConfig + ): + if isinstance(worker_config, config.InfererConfig): + raise TypeError("Please provide a valid worker config object") + return InferenceWorker(worker_config=worker_config) def _set_worker_config(self) -> config.InferenceWorkerConfig: self.model_info = config.ModelInfo( diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index e11eb3de..2a131a5f 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -841,8 +841,14 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.TrainingWorkerConfig): - return TrainingWorker(config=config) + def _create_worker_from_config( + self, worker_config: config.TrainingWorkerConfig + ): + if isinstance(config, config.TrainerConfig): + raise TypeError( + "Expected a TrainingWorkerConfig, got a TrainerConfig" + ) + return TrainingWorker(worker_config=worker_config) def _set_worker_config(self) -> config.TrainingWorkerConfig: model_config = config.ModelInfo(name=self.model_choice.currentText()) From 7188c64d6f5449bdb330616bc0519fe836b33d40 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:31:36 +0200 Subject: [PATCH 412/577] Slight coverage increase --- napari_cellseg3d/_tests/test_plugin_inference.py | 13 ++----------- napari_cellseg3d/_tests/test_training.py | 1 + napari_cellseg3d/code_models/models/model_test.py | 2 +- napari_cellseg3d/code_models/workers.py | 6 +++--- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 04305082..c437ac83 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -39,23 +39,14 @@ def test_inference(make_napari_viewer, qtbot): assert widget.worker_config is not None assert widget.model_info is not None worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data + worker.config.post_process_config.instance.enabled = True assert worker.config.layer is not None worker.log_parameters() res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) - - # def on_error(e): - # print(e) - # assert False - # with qtbot.waitSignal( - # signal=worker.finished, timeout=10000, raising=True - # ) as blocker: - # worker.error_signal.connect(on_error) - # blocker.connect(worker.errored) - # worker.inference() # takes too long on Github Actions - # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 080df419..e7f1e07b 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -43,6 +43,7 @@ def test_training(make_napari_viewer, qtbot): worker = widget._create_worker_from_config(worker_config) worker.config.train_data_dict = [{"image": im_path, "label": im_path}] worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.config.max_epochs = 1 worker.log_parameters() res = next(worker.train()) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1ccac3da..1cb52f06 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -8,7 +8,7 @@ class TestModel(nn.Module): def __init__(self, **kwargs): super().__init__() - self.linear = nn.Linear(1, 1) + self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 6dd32c80..8ddc7921 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1425,9 +1425,9 @@ def get_loader_func(num_samples): device = self.config.device - if model_name == "test": - self.quit() - yield TrainingReport(False) + # if model_name == "test": + # self.quit() + # yield TrainingReport(False) for epoch in range(self.config.max_epochs): # self.log("\n") From bd72e72a7f6d7edd8bddaef1d748c9a1116fca44 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:45:47 +0200 Subject: [PATCH 413/577] Update test_plugin_inference.py --- napari_cellseg3d/_tests/test_plugin_inference.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index c437ac83..ca8e84d4 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,6 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.instance_segmentation import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import ( InferenceResult, @@ -44,6 +47,10 @@ def test_inference(make_napari_viewer, qtbot): assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data worker.config.post_process_config.instance.enabled = True + worker.config.post_process_config.instance.method = ( + INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() + ) + assert worker.config.layer is not None worker.log_parameters() From 17426c99021138509b0608f52d8e02d983c9520f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 11:41:39 +0200 Subject: [PATCH 414/577] Set window inference to 64 for WNet --- .../code_plugins/plugin_model_inference.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 1d8c0620..74dc62e5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -119,6 +119,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size ) + self.model_choice.currentIndexChanged.connect( + self._restrict_window_size_for_model + ) self.model_choice.setCurrentIndex(0) self.anisotropy_wdgt = ui.AnisotropyWidgets( @@ -150,9 +153,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.window_infer_box = ui.CheckBox("Use window inference") - self.window_infer_box.clicked.connect(self._toggle_display_window_size) + self.window_infer_box.toggled.connect(self._toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + self._default_window_size = sizes_window.index("64") # ( # self.window_size_choice, # self.window_size_choice.label, @@ -167,7 +171,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(3) # set to 64 by default + self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -192,7 +196,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) - self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## @@ -299,6 +302,19 @@ def check_ready(self): return True return False + def _restrict_window_size_for_model(self): + """Sets the window size to a value that is compatible with the chosen model""" + if self.model_choice.currentText() == "WNet": + self.window_size_choice.setCurrentIndex(self._default_window_size) + self.window_size_choice.setDisabled(True) + self.window_infer_box.setChecked(True) + self.window_infer_box.setDisabled(True) + else: + self.window_size_choice.setDisabled(False) + self.window_infer_box.setDisabled(False) + self.window_infer_box.setChecked(False) + self.window_size_choice.setCurrentIndex(self._default_window_size) + def _toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" From eb6b199d5cc86e499dc43868f2be4e5b87b068c7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 22:00:16 +0200 Subject: [PATCH 415/577] Update instance_segmentation.py --- napari_cellseg3d/code_models/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 2240e3bd..93de0768 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -434,7 +434,7 @@ def sphericity(region): return ImageStats( volume, [region.centroid[0] for region in properties], - [region.centroid[0] for region in properties], + [region.centroid[1] for region in properties], [region.centroid[2] for region in properties], sphericity_ax, fill([volume_image.shape]), From ad4069c2635741d133f7657b05588df8ada2c993 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 20 May 2023 09:22:52 +0200 Subject: [PATCH 416/577] Moved normalization to the correct place --- napari_cellseg3d/code_models/workers.py | 2 +- napari_cellseg3d/utils.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8ddc7921..dd9e38e3 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -492,9 +492,9 @@ def model_output( logger.debug(f"inputs type : {inputs.dtype}") try: # outputs = model(inputs) + inputs = utils.remap_image(inputs) def model_output_wrapper(inputs): - inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7ca29e00..90a64cfb 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -223,12 +223,18 @@ def normalize_max(image): def remap_image( - image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 + image: Union["np.ndarray", "torch.Tensor"], + new_max=100, + new_min=0, + prev_max=None, + prev_min=None, ): """Normalizes a numpy array or Tensor using the max and min value""" shape = image.shape image = image.flatten() - image = (image - image.min()) / (image.max() - image.min()) + im_max = prev_max if prev_max is not None else image.max() + im_min = prev_min if prev_min is not None else image.min() + image = (image - im_min) / (im_max - im_min) image = image * (new_max - new_min) + new_min image = image.reshape(shape) return image From 6c3e43864087f3484811c7ceb75515a487e6968c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 11:09:48 +0200 Subject: [PATCH 417/577] Added auto-set dims for cropping --- napari_cellseg3d/code_plugins/plugin_crop.py | 38 +++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index a27b4baa..e3ea55f5 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -3,6 +3,7 @@ import napari import numpy as np from magicgui import magicgui +from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -43,6 +44,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") + self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -112,6 +114,8 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() self._toggle_second_image_io_visibility() + self._check_image_list() + self.auto_set_dims() def _toggle_second_image_io_visibility(self): crop_2nd = self.crop_second_image_choice.isChecked() @@ -132,6 +136,16 @@ def _check_image_list(self): except IndexError: return + def auto_set_dims(self): + logger.debug(self.image_layer_loader.layer_name()) + data = self.image_layer_loader.layer_data() + if data is not None: + logger.debug("auto_set_dims : {}".format(data.shape)) + if len(data.shape) == 3: + for i, box in enumerate(self.crop_size_widgets): + logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + box.setValue(floor(data.shape[i] / 2)) + def _build(self): """Build buttons in a layout and add them to the napari Viewer""" @@ -266,9 +280,9 @@ def _start(self): except ValueError as e: logger.warning(e) logger.warning( - "Could not remove cropping layer programmatically!" + "Could not remove the previous cropping layer programmatically." ) - logger.warning("Maybe layer has been removed by user?") + # logger.warning("Maybe layer has been removed by user?") self.results_path = Path(self.results_filewidget.text_field.text()) @@ -346,7 +360,7 @@ def add_isotropic_layer( layer.data, name=f"Scaled_{layer.name}", colormap=colormap, - contrast_limits=contrast_lim, + # contrast_limits=contrast_lim, opacity=opacity, scale=self.aniso_factors, visible=visible, @@ -481,8 +495,8 @@ def set_slice( """ "Update cropped volume position""" # self._check_for_empty_layer(highres_crop_layer, highres_crop_layer.data) - logger.debug(f"axis : {axis}") - logger.debug(f"value : {value}") + # logger.debug(f"axis : {axis}") + # logger.debug(f"value : {value}") idx = int(value) scale = np.asarray(highres_crop_layer.scale) @@ -496,6 +510,20 @@ def set_slice( cropy = self._crop_size_y cropz = self._crop_size_z + if i + cropx > im1_stack.shape[0]: + cropx = im1_stack.shape[0] - i + if j + cropy > im1_stack.shape[1]: + cropy = im1_stack.shape[1] - j + if k + cropz > im1_stack.shape[2]: + cropz = im1_stack.shape[2] - k + + logger.debug(f"cropx : {cropx}") + logger.debug(f"cropy : {cropy}") + logger.debug(f"cropz : {cropz}") + logger.debug(f"i : {i}") + logger.debug(f"j : {j}") + logger.debug(f"k : {k}") + highres_crop_layer.data = im1_stack[ i : i + cropx, j : j + cropy, k : k + cropz ] From 6c82e2b035d711dd44b35fff40e76518842e2f4d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 12:19:37 +0200 Subject: [PATCH 418/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 21 +++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 584be4d7..60c25ccc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,27 +1,34 @@ -from pathlib import Path - import numpy as np -from tifffile import imread +from numpy.random import PCG64, Generator from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, ) +rand_gen = Generator(PCG64(12345)) + def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - im_path = str(Path(__file__).resolve().parent / "res/test.tif") - image = imread(im_path) - view.add_image(image) - view.add_labels(image.astype(np.uint8)) + image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image_layer = view.add_image(image, name="image") + label_layer = view.add_labels(image.astype(np.uint8), name="labels") view.window.add_dock_widget(widget) + view.dims.ndisplay = 3 for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + + assert len(image_layer.data.shape) == 3 + assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() From 23bb52fd89935d049161c09e880c13d589af9d3e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:50:18 +0200 Subject: [PATCH 419/577] More WNet - Added experimental .pt loading for jit models - More CRF tests - Optimized WNet by loading inference only --- napari_cellseg3d/_tests/test_models.py | 61 ++++++++++++------ napari_cellseg3d/code_models/crf.py | 8 ++- .../code_models/model_framework.py | 2 +- .../code_models/models/model_WNet.py | 18 +++--- .../code_models/models/wnet/model.py | 19 ++++-- napari_cellseg3d/code_models/workers.py | 62 ++++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 19 +++--- .../dev_scripts/correct_labels.py | 12 ++-- pyproject.toml | 1 + 9 files changed, 124 insertions(+), 78 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 4852f651..c67b3cab 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -2,9 +2,14 @@ import torch from numpy.random import PCG64, Generator -from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf +from napari_cellseg3d.code_models.crf import ( + CRFWorker, + correct_shape_for_crf, + crf_batch, + crf_with_config, +) from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss -from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d.config import MODEL_LIST, CRFConfig rand_gen = Generator(PCG64(12345)) @@ -47,7 +52,38 @@ def test_soft_ncuts_loss(): assert 0 <= res <= 1 -def test_crf(qtbot): +def test_crf_batch(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_batch( + np.array([mock_image, mock_image, mock_image]), + np.array([mock_label, mock_label, mock_label]), + sa=config.sa, + sb=config.sb, + sg=config.sg, + w1=config.w1, + w2=config.w2, + ) + + assert isinstance(result, np.ndarray) + assert result.shape == (3, 2, dims, dims, dims) + + +def test_crf_config(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_with_config(mock_image, mock_label, config) + assert isinstance(result, np.ndarray) + assert result.shape == mock_label.shape + + +def test_crf_worker(qtbot): dims = 8 mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) @@ -60,20 +96,5 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=True - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() - - mock_image = mock_image[0] - mock_label = mock_label[0] - - crf = CRFWorker(mock_image, mock_label) - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=False - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() + result = next(crf._run_crf_job()) + on_yield(result) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 8c311059..b362246a 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,6 +54,8 @@ def correct_shape_for_crf(image, desired_dims=4): + logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") + logger.debug(f"Image shape: {image.shape}") if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( @@ -62,6 +64,7 @@ def correct_shape_for_crf(image, desired_dims=4): image = np.squeeze(image, axis=0) elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) + logger.debug(f"Corrected image shape: {image.shape}") return image @@ -210,9 +213,12 @@ def _run_crf_job(self): if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - im = correct_shape_for_crf(self.labels[i]) + im = correct_shape_for_crf(self.images[i]) prob = correct_shape_for_crf(self.labels[i]) + logger.debug(f"image shape : {im.shape}") + logger.debug(f"labels shape : {prob.shape}") + yield crf( im, prob, diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 60644916..0296e0cf 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -281,7 +281,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth)", + filetype="Weights file (*.pth, *.pt)", ) if file[0] == self._default_weights_folder: return diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index cb5ef6d8..62142e73 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,8 +1,8 @@ # local -from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder -class WNet_(WNet): +class WNet_(WNet_encoder): use_default_training = False weights_file = "wnet.pth" @@ -24,13 +24,13 @@ def __init__( # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") - def forward(self, x): - """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference. - """ - return self.forward_encoder(x) - # enc = self.forward_encoder(x) - # return self.forward_decoder(enc) + # def forward(self, x): + # """Forward ENCODER pass of the W-Net model. + # Done this way to allow inference on the encoder only when called by sliding_window_inference. + # """ + # return self.forward_encoder(x) + # # enc = self.forward_encoder(x) + # # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 585ea0dd..a23084d0 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -16,6 +16,19 @@ ] +class WNet_encoder(nn.Module): + """WNet with encoder only.""" + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super().__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + + def forward(self, x): + """Forward pass of the W-Net model.""" + return self.forward_encoder(x) + + class WNet(nn.Module): """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. The model performs unsupervised segmentation of 3D images. @@ -36,13 +49,11 @@ def forward(self, x): def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" - enc = self.encoder(x) - return enc + return self.encoder(x) def forward_decoder(self, enc): """Forward pass of the decoder part of the W-Net model.""" - dec = self.decoder(enc) - return dec + return self.decoder(enc) class UNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index dd9e38e3..8b3da42d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -820,41 +820,43 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config - - # try: - self.log("Instantiating model...") - model = model_class( # FIXME test if works - input_img_size=[dims, dims, dims], - device=self.config.device, - num_classes=self.config.model_info.num_classes, - ) - # try: - model = model.to(self.config.device) - # except Exception as e: - # self.raise_error(e, "Issue loading model to device") - # logger.debug(f"model : {model}") - if model is None: - raise ValueError("Model is None") + if Path(weights_config.path).suffix == ".pt": + model = torch.jit.load(weights_config.path) # try: - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path else: - self.downloader.download_weights( - model_name, - model_class.weights_file, - ) - weights = str( - PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[dims, dims, dims], + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, + ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + ) - model.load_state_dict( # note that this is redefined in WNet_ - torch.load( - weights, - map_location=self.config.device, + model.load_state_dict( # note that this is redefined in WNet_ + torch.load( + weights, + map_location=self.config.device, + ) ) - ) - self.log("Done") + self.log("Done") # except Exception as e: # self.raise_error(e, "Issue loading weights") # except Exception as e: diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index e3ea55f5..74691e1f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,9 +1,9 @@ +from math import floor from pathlib import Path import napari import numpy as np from magicgui import magicgui -from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -44,7 +44,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") - self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self.auto_set_dims + ) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -140,10 +142,12 @@ def auto_set_dims(self): logger.debug(self.image_layer_loader.layer_name()) data = self.image_layer_loader.layer_data() if data is not None: - logger.debug("auto_set_dims : {}".format(data.shape)) + logger.debug(f"auto_set_dims : {data.shape}") if len(data.shape) == 3: for i, box in enumerate(self.crop_size_widgets): - logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + logger.debug( + f"setting dim {i} to {floor(data.shape[i]/2)}" + ) box.setValue(floor(data.shape[i] / 2)) def _build(self): @@ -433,9 +437,8 @@ def _add_crop_sliders( box.value() for box in self.crop_size_widgets ] ############# - dims = [self._x, self._y, self._z] - [logger.debug(f"{dim}") for dim in dims] - logger.debug("SET DIMS ATTEMPT") + # [logger.debug(f"{dim}") for dim in dims] + # logger.debug("SET DIMS ATTEMPT") # if not self.create_new_layer.isChecked(): # self._x = x # self._y = y @@ -451,6 +454,8 @@ def _add_crop_sliders( # define crop sizes and boundaries for the image crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] + # [logger.debug(f"{crop}") for crop in crop_sizes] + # logger.debug("SET CROP ATTEMPT") for i in range(len(crop_sizes)): if crop_sizes[i] > im1_stack.shape[i]: diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4a7363b2..f413812d 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -363,9 +363,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") -# -# image_path = str(im_path / "volumes/images.tif") -# gt_labels_path = str(im_path / "labels/testing_im.tif") -# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") + + image_path = str(im_path / "volumes/c1images.tif") + gt_labels_path = str(im_path / "labels/c1labels.tif") + relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) diff --git a/pyproject.toml b/pyproject.toml index 7210af6e..87cc2e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ docs = [ test = [ "pytest", "pytest_qt", + "pytest-cov", "coverage", "tox", "twine", From 0ef93ac0350e1141001200b4f9a27626e49859c6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:12:07 +0200 Subject: [PATCH 420/577] Update crf test/deps for testing --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index fa6905d5..0911e358 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index c67b3cab..ec7462db 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -68,7 +68,6 @@ def test_crf_batch(): w2=config.w2, ) - assert isinstance(result, np.ndarray) assert result.shape == (3, 2, dims, dims, dims) @@ -79,7 +78,6 @@ def test_crf_config(): config = CRFConfig() result = crf_with_config(mock_image, mock_label, config) - assert isinstance(result, np.ndarray) assert result.shape == mock_label.shape @@ -91,7 +89,6 @@ def test_crf_worker(qtbot): crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): - assert isinstance(result, np.ndarray) assert len(result.shape) == 4 assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] From 7882738140888c1d6413a5e174d2f199a04e8d46 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:20:30 +0200 Subject: [PATCH 421/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0911e358..d09be5f0 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -87,6 +86,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From fee1fccc2315a873211de7fe84798a553eb457ec Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:34:33 +0200 Subject: [PATCH 422/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d09be5f0..d36e03a3 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -86,7 +87,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From 8b64602a794f6a84969004fa45c326e51247ecd8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:28 +0200 Subject: [PATCH 423/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 0a7c07f0..ee033e59 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From b621da1849bab29a83bc5bf929eee0b742531e0f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:45 +0200 Subject: [PATCH 424/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d36e03a3..60bc5505 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox From caf306032412f2f559536b741d32892cd244a21b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:50:44 +0200 Subject: [PATCH 425/577] Trying to fix tox install of pydensecrf --- .github/workflows/test_and_deploy.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 60bc5505..e9a66ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions -# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/tox.ini b/tox.ini index ee033e59..ba3e8805 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From cb29858ebb708f66a7bb447d7d7e7d0057d246d2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:23:51 +0200 Subject: [PATCH 426/577] Added experimental ONNX support for inference --- .../code_models/model_framework.py | 15 ++++---- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 34 ++++++++++++++++++- .../code_plugins/plugin_model_inference.py | 14 +++++++- pyproject.toml | 8 +++++ 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 0296e0cf..f379ccb8 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -273,6 +273,14 @@ def get_available_models(): # self.lbl_model_path.setText(self.model_path) # # self.update_default() + def _update_weights_path(self, file): + if file[0] == self._default_weights_folder: + return + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) + def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`""" @@ -283,12 +291,7 @@ def _load_weights_path(self): [self._default_weights_folder], filetype="Weights file (*.pth, *.pt)", ) - if file[0] == self._default_weights_folder: - return - if file is not None and file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + self._update_weights_path(file) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index a23084d0..f98829bb 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -26,7 +26,7 @@ def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): def forward(self, x): """Forward pass of the W-Net model.""" - return self.forward_encoder(x) + return self.encoder(x) class WNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 8b3da42d..be88c835 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,6 +199,34 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files +class ONNXModelWrapper(torch.nn.Module): + """Class to replace torch model if ONNX is used""" + def __init__(self, file_location): + super().__init__() + try: + import onnx + import onnxruntime as ort + except ImportError as e: + logger.error("ONNX is not installed but ONNX model was loaded") + logger.error(e) + msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" + logger.error(msg) + raise ImportError(msg) + + self.ort_session = ort.InferenceSession( + file_location, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + + def forward(self, modeL_input): + outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + return torch.tensor(outputs[0]) + + def eval(self): + return True + + def to(self, device): + return True @dataclass class InferenceResult: @@ -821,9 +849,13 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config if Path(weights_config.path).suffix == ".pt": + self.log("Instantiating PyTorch jit model...") model = torch.jit.load(weights_config.path) # try: - else: + elif Path(weights_config.path).suffix == ".onnx": + self.log("Instantiating ONNX model...") + model = ONNXModelWrapper(weights_config.path) + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 74dc62e5..599ec5b3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING - +from pathlib import Path import numpy as np import pandas as pd @@ -348,6 +348,18 @@ def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box""" ui.toggle_visibility(self.window_infer_box, self.window_infer_params) + def _load_weights_path(self): + """Show file dialog to set :py:attr:`model_path`""" + + # logger.debug(self._default_weights_folder) + + file = ui.open_file_dialog( + self, + [self._default_weights_folder], + filetype="Weights file (*.pth, *.pt, *.onnx)", + ) + self._update_weights_path(file) + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer""" diff --git a/pyproject.toml b/pyproject.toml index 87cc2e1d..2783761e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,3 +118,11 @@ test = [ "twine", "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] +onnx-cpu = [ + "onnx", + "onnxruntime" +] +onnx-gpu = [ + "onnx", + "onnxruntime-gpu" +] From 4e1b0a8f42b3cbe70f9a14d9311e06512498522f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:47:48 +0200 Subject: [PATCH 427/577] Updated WNet for ONNX conversion --- .../code_models/models/wnet/model.py | 59 +++++++++++-------- napari_cellseg3d/code_models/workers.py | 9 ++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index f98829bb..23584b30 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -59,18 +59,33 @@ def forward_decoder(self, enc): class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, encoder=True): + def __init__( + self, device, in_channels, out_channels, encoder=True, dropout=0.65 + ): super(UNet, self).__init__() self.device = device - self.in_b = InBlock(device, in_channels, 64) - self.conv1 = Block(device, 64, 128) - self.conv2 = Block(device, 128, 256) - self.conv3 = Block(device, 256, 512) - self.bot = Block(device, 512, 1024) - self.deconv1 = Block(device, 1024, 512) - self.deconv2 = Block(device, 512, 256) - self.deconv3 = Block(device, 256, 128) - self.out_b = OutBlock(device, 128, out_channels) + self.max_pool = nn.MaxPool3d(2) + self.in_b = InBlock(device, in_channels, 64, dropout=dropout) + self.conv1 = Block(device, 64, 128, dropout=dropout) + self.conv2 = Block(device, 128, 256, dropout=dropout) + self.conv3 = Block(device, 256, 512, dropout=dropout) + self.bot = Block(device, 512, 1024, dropout=dropout) + self.deconv1 = Block(device, 1024, 512, dropout=dropout) + self.conv_trans1 = nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + ) + self.deconv2 = Block(device, 512, 256, dropout=dropout) + self.conv_trans2 = nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + ) + self.deconv3 = Block(device, 256, 128, dropout=dropout) + self.conv_trans3 = nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + ) + self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) + self.conv_trans_out = nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + ) self.sm = nn.Softmax(dim=1).to(device) self.encoder = encoder @@ -78,17 +93,15 @@ def __init__(self, device, in_channels, out_channels, encoder=True): def forward(self, x): """Forward pass of the U-Net model.""" in_b = self.in_b(x.to(self.device)) - c1 = self.conv1(nn.MaxPool3d(2)(in_b)) - c2 = self.conv2(nn.MaxPool3d(2)(c1)) - c3 = self.conv3(nn.MaxPool3d(2)(c2)) - x = self.bot(nn.MaxPool3d(2)(c3)) + c1 = self.conv1(self.max_pool(in_b)) + c2 = self.conv2(self.max_pool(c1)) + c3 = self.conv3(self.max_pool(c2)) + x = self.bot(self.max_pool(c3)) x = self.deconv1( torch.cat( [ c3, - nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device - )(x), + self.conv_trans1(x), ], dim=1, ) @@ -97,9 +110,7 @@ def forward(self, x): torch.cat( [ c2, - nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device - )(x), + self.conv_trans2(x), ], dim=1, ) @@ -108,9 +119,7 @@ def forward(self, x): torch.cat( [ c1, - nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device - )(x), + self.conv_trans3(x), ], dim=1, ) @@ -119,9 +128,7 @@ def forward(self, x): torch.cat( [ in_b, - nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device - )(x), + self.conv_trans_out(x), ], dim=1, ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index be88c835..bf6b8542 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -200,7 +200,7 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files class ONNXModelWrapper(torch.nn.Module): - """Class to replace torch model if ONNX is used""" + """Class to replace torch model by ONNX Runtime session""" def __init__(self, file_location): super().__init__() try: @@ -219,14 +219,17 @@ def __init__(self, file_location): ) def forward(self, modeL_input): + """Wraps ONNX output in a torch tensor""" outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) return torch.tensor(outputs[0]) def eval(self): - return True + """Dummy function to replace model.eval()""" + pass def to(self, device): - return True + """Dummy function to replace model.to(device)""" + pass @dataclass class InferenceResult: From a052d2cd4c470fcb141578dcb030830f904eb678 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:56:45 +0200 Subject: [PATCH 428/577] Added dropout param --- .../code_models/models/wnet/model.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 23584b30..3416acb1 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -141,17 +141,17 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(InBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -163,19 +163,19 @@ def forward(self, x): class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(Block, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), nn.Conv3d(in_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.Conv3d(out_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -187,21 +187,21 @@ def forward(self, x): class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(OutBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, out_channels, 1, device=device), ).to(device) def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) + return self.module(x.to(self.device)) \ No newline at end of file From ea9115974294c52a033f0455b3e62510e603f927 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 31 May 2023 16:13:42 +0200 Subject: [PATCH 429/577] Minor fixes in training --- napari_cellseg3d/code_models/workers.py | 8 ++++---- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 +++- napari_cellseg3d/interface.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index bf6b8542..c67ea523 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1531,13 +1531,13 @@ def get_loader_func(num_samples): or epoch + 1 == self.config.max_epochs ): model.eval() + self.log("Performing validation...") with torch.no_grad(): for val_data in val_loader: val_inputs, val_labels = ( val_data["image"].to(device), val_data["label"].to(device), ) - self.log("Performing validation...") try: with torch.no_grad(): val_outputs = sliding_window_inference( @@ -1606,8 +1606,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1620,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 2a131a5f..3e666dcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -169,6 +169,8 @@ def __init__( self.validation_values = [] # self.model_choice.setCurrentIndex(0) + wnet_index = self.model_choice.findText("WNet") + self.model_choice.removeItem(wnet_index) ################################ # interface @@ -813,7 +815,7 @@ def start(self): ) self._set_worker_config() - self.worker = TrainingWorker(config=self.worker_config) + self.worker = TrainingWorker(worker_config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index df00ad0b..e5b189ef 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1235,7 +1235,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") return QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + "/.." + widget, "Open directory", default_path # + "/.." ) From f07d3eb581bf6fffdf6792ab481043c1a82c668d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:31:23 +0200 Subject: [PATCH 430/577] Fix weights file extension in inference + coverage - Remove unused scripts - More tests - Fixed weights type in inference --- .coveragerc | 7 + .gitignore | 1 + napari_cellseg3d/_tests/test_dock_widget.py | 1 + .../_tests/test_labels_correction.py | 8 +- .../_tests/test_plugin_inference.py | 2 + napari_cellseg3d/_tests/test_plugins.py | 21 ++ napari_cellseg3d/_tests/test_utils.py | 29 ++- .../code_models/model_framework.py | 28 +-- .../code_models/models/wnet/crf.py | 112 --------- napari_cellseg3d/code_plugins/plugin_crf.py | 6 +- .../code_plugins/plugin_metrics.py | 2 +- .../code_plugins/plugin_model_inference.py | 8 +- napari_cellseg3d/dev_scripts/convert.py | 26 -- napari_cellseg3d/dev_scripts/drafts.py | 15 -- .../dev_scripts/evaluate_labels.py | 2 +- .../extract_extra_channels_labels.py | 144 ----------- napari_cellseg3d/dev_scripts/view_brain.py | 8 - napari_cellseg3d/dev_scripts/view_sample.py | 29 --- .../dev_scripts/weight_conversion.py | 234 ------------------ napari_cellseg3d/interface.py | 6 +- napari_cellseg3d/utils.py | 2 +- tox.ini | 4 +- 22 files changed, 75 insertions(+), 620 deletions(-) create mode 100644 .coveragerc create mode 100644 napari_cellseg3d/_tests/test_plugins.py delete mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py delete mode 100644 napari_cellseg3d/dev_scripts/convert.py delete mode 100644 napari_cellseg3d/dev_scripts/drafts.py delete mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py delete mode 100644 napari_cellseg3d/dev_scripts/view_brain.py delete mode 100644 napari_cellseg3d/dev_scripts/view_sample.py delete mode 100644 napari_cellseg3d/dev_scripts/weight_conversion.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..038f3d5a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +exclude_lines = + if __name__ == .__main__.: + +[run] +omit = + napari_cellseg3d/setup.py diff --git a/.gitignore b/.gitignore index df67a187..7460d861 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ notebooks/instance_test.ipynb !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif +cov.syspath.txt diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index 7737e540..8063c92b 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -11,6 +11,7 @@ def test_prepare(make_napari_viewer): viewer = make_napari_viewer() viewer.add_image(image) widget = Datamanager(viewer) + viewer.window.add_dock_widget(widget) widget.prepare(path_image, ".tif", "", False) diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..b4f13238 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -37,16 +37,16 @@ def test_correct_labels(): ) -def test_relabel(make_napari_viewer): - viewer = make_napari_viewer() +def test_relabel(): cl.relabel( str(image_path), str(labels_path), go_fast=True, - viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance(labels, labels, print_details=True) + el.evaluate_model_performance( + labels, labels, print_details=True, visualize=False + ) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index ca8e84d4..1ae83102 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -57,3 +57,5 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) + + widget.on_yield(res) diff --git a/napari_cellseg3d/_tests/test_plugins.py b/napari_cellseg3d/_tests/test_plugins.py new file mode 100644 index 00000000..c58d26af --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugins.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from napari_cellseg3d import plugins +from napari_cellseg3d.code_plugins import plugin_metrics as m + + +def test_all_plugins_import(make_napari_viewer): + plugins.napari_experimental_provide_dock_widget() + + +def test_plugin_metrics(make_napari_viewer): + viewer = make_napari_viewer() + w = m.MetricsUtils(viewer=viewer, parent=None) + viewer.window.add_dock_widget(w) + + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + labels_path = im_path + + w.image_filewidget.text_field = im_path + w.labels_filewidget.text_field = labels_path + w.compute_dice() diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 0b28183d..dc680b35 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,14 +1,15 @@ -import os from functools import partial +from pathlib import Path import numpy as np import torch from napari_cellseg3d import utils +from napari_cellseg3d.dev_scripts import thread_test def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] + test_list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -30,11 +31,11 @@ def test_fill_list_in_between(): "", ] - assert utils.fill_list_in_between(list, 2, "") == res + assert utils.fill_list_in_between(test_list, 2, "") == res fill = partial(utils.fill_list_in_between, n=2, fill_value="") - assert fill(list) == res + assert fill(test_list) == res def test_align_array_sizes(): @@ -109,11 +110,19 @@ def test_normalize_x(): def test_parse_default_path(): - user_path = os.path.expanduser("~") - assert utils.parse_default_path([None]) == user_path + user_path = Path().home() + assert utils.parse_default_path([None]) == str(user_path) - path = ["C:/test/test", None, None] - assert utils.parse_default_path(path) == "C:/test/test" + test_path = "C:/test/test" + path = [test_path, None, None] + assert utils.parse_default_path(path) == test_path - path = ["C:/test/test", None, None, "D:/very/long/path/what/a/bore", ""] - assert utils.parse_default_path(path) == "D:/very/long/path/what/a/bore" + long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime" + path = [test_path, None, None, long_path, ""] + assert utils.parse_default_path(path) == long_path + + +def test_thread_test(make_napari_viewer): + viewer = make_napari_viewer() + w = thread_test.create_connected_widget(viewer) + viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index f379ccb8..ddd9cd28 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -289,7 +289,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt)", + file_extension="Weights file (*.pth)", ) self._update_weights_path(file) @@ -311,31 +311,5 @@ def empty_cuda_cache(self): torch.cuda.empty_cache() logger.info("Attempt complete : Cache emptied") - # def update_default(self): # TODO add custom models - # """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" - # - # if len(self.images_filepaths) != 0: - # from_images = str(Path(self.images_filepaths[0]).parent) - # else: - # from_images = None - # - # if len(self.labels_filepaths) != 0: - # from_labels = str(Path(self.labels_filepaths[0]).parent) - # else: - # from_labels = None - # - # possible_paths = [ - # path - # for path in [ - # from_images, - # from_labels, - # # self.model_path, - # self.results_path, - # ] - # if path is not None - # ] - # self._default_folders = possible_paths - # update if model_path is used again - def _build(self): raise NotImplementedError("Should be defined in children classes") diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py deleted file mode 100644 index 004db3a1..00000000 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Implements the CRF post-processing step for the W-Net. -Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - -Also uses research from: -Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials -Philipp Krähenbühl and Vladlen Koltun -NIPS 2011 - -Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. -""" - -import numpy as np -import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - create_pairwise_bilateral, - create_pairwise_gaussian, - unary_from_softmax, -) - -__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" -__credits__ = [ - "Yves Paychère", - "Colin Hofmann", - "Cyril Achard", - "Philipp Krähenbühl", - "Vladlen Koltun", - "Liang-Chieh Chen", - "George Papandreou", - "Iasonas Kokkinos", - "Kevin Murphy", - "Alan L. Yuille", - "Xide Xia", - "Brian Kulis", - "Lucas Beyer", -] - - -def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): - """CRF post-processing step for the W-Net, applied to a batch of images. - - Args: - images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. - probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. - """ - - return np.stack( - [ - crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) - for i in range(images.shape[0]) - ], - axis=0, - ) - - -def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): - """Implements the CRF post-processing step for the W-Net. - Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - Implemented using the pydensecrf library. - - Args: - image (np.ndarray): Array of shape (C, H, W, D) containing the input image. - prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. - """ - d = dcrf.DenseCRF( - image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] - ) - # print(f"Image shape : {image.shape}") - # print(f"Prob shape : {prob.shape}") - # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels - - # Get unary potentials from softmax probabilities - U = unary_from_softmax(prob) - d.setUnaryEnergy(U) - - # Generate pairwise potentials - featsGaussian = create_pairwise_gaussian( - sdims=(sg, sg, sg), shape=image.shape[1:] - ) # image.shape) - featsBilateral = create_pairwise_bilateral( - sdims=(sa, sa, sa), - schan=tuple([sb for i in range(image.shape[0])]), - img=image, - chdim=-1, - ) - - # Add pairwise potentials to the CRF - compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( - [1 for i in range(prob.shape[0])] - # , dtype=np.float32 - ) - d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) - d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) - - # Run inference - Q = d.inference(n_iter) - - return np.array(Q).reshape( - (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) - ) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index d8407a0f..76194e87 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial from pathlib import Path @@ -277,7 +278,10 @@ def _on_start(self): def _on_finish(self): self.worker = None - self.start_button.setText("Start") + with contextlib.suppress(RuntimeError): + self.start_button.setText("Start") + + # should only happen when testing def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 2a6e713c..1dc5e7de 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -23,7 +23,7 @@ class MetricsUtils(BasePluginFolder): """Plugin to evaluate metrics between two sets of labels, ground truth and prediction""" - def __init__(self, viewer: "napari.viewer.Viewer", parent): + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): """Creates a MetricsUtils widget for computing and plotting dice metrics between labels. Args: viewer: viewer to display the widget in diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 599ec5b3..256cffa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING -from pathlib import Path + import numpy as np import pandas as pd @@ -171,7 +171,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default + self.window_size_choice.setCurrentIndex( + self._default_window_size + ) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -356,7 +358,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt, *.onnx)", + file_extension="Weights file (*.pth *.pt *.onnx)", ) self._update_weights_path(file) diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py deleted file mode 100644 index 641de627..00000000 --- a/napari_cellseg3d/dev_scripts/convert.py +++ /dev/null @@ -1,26 +0,0 @@ -import glob -import os - -import numpy as np -from tifffile import imread, imwrite - -# input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" -# output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" - -input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/labels" -output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/lab_sem" - -filenames = [] -paths = [] -filetype = ".tif" -for filename in glob.glob(os.path.join(input_seg_path, "*" + filetype)): - paths.append(filename) - filenames.append(os.path.basename(filename)) - # print(os.path.basename(filename)) -for file in paths: - image = imread(file) - - image[image >= 1] = 1 - image = image.astype(np.uint16) - - imwrite(output_seg_path + "/" + os.path.basename(file), image) diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py deleted file mode 100644 index cdd02256..00000000 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ /dev/null @@ -1,15 +0,0 @@ -import napari -import numpy as np -from magicgui import magicgui -from napari.types import ImageData, LabelsData - - -@magicgui(call_button="Run Threshold") -def threshold(image: ImageData, threshold: int = 75) -> LabelsData: - """Threshold an image and return a mask.""" - return (image > threshold).astype(int) - - -viewer = napari.view_image(np.random.randint(0, 100, (64, 64))) -viewer.window.add_dock_widget(threshold) -threshold() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 26b45d3f..00bce5ec 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -134,7 +134,7 @@ def evaluate_model_performance( log.info(mean_ratio_false_pixel_artefact) if visualize: - viewer = napari.Viewer() + viewer = napari.Viewer(ndisplay=3) viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py deleted file mode 100644 index 70ee10b6..00000000 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -from skimage.filters import threshold_otsu -from skimage.segmentation import expand_labels -from tqdm import tqdm - - -def extract_labels_from_channels( # TODO add separate channels results - nuclei_labels: np.array, - extra_channels: list, - radius: int = 4, - threshold_factor=2, - viewer=None, -): - """ - Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. - Args: - nuclei_labels (np.array): labels for the nuclei - extra_channels (list): channels arrays to extract labels from - radius: radius in which the approximation is made - - Returns: - A list of extracted labels for each extra channel - """ - labeled_channels = [] - contrasted_channels = [] - for channel in extra_channels: - channel = (channel - np.min(channel)) / ( - np.max(channel) - np.min(channel) - ) - threshold_brightness = threshold_otsu(channel) * threshold_factor - channel_contrasted = np.where( - channel > threshold_brightness, channel, 0 - ) - contrasted_channels.append(channel_contrasted) - if viewer is not None: - viewer.add_image( - channel_contrasted, - name="channel_contrasted", - colormap="viridis", - ) - for label_id in tqdm(np.unique(nuclei_labels)): - if label_id == 0: - continue - label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) - expanded = expand_labels(label_nucleus, distance=radius) - restricted = np.where(expanded != 0, nuclei_labels, 0) - overlap = np.where(restricted != label_id, restricted, 0) - - for i, channel in enumerate(contrasted_channels): - label_contrasted = np.where(expanded != 0, channel, 0) - if overlap.any() != 0: - max_labeled = 0 - for overlap_id in np.unique(overlap): - if overlap_id == 0: - continue - assigned_pixels = np.count_nonzero( - np.where(overlap == overlap_id, channel, 0) - ) - if assigned_pixels > max_labeled: - max_labeled = assigned_pixels - max_label_id = overlap_id - if label_id != max_label_id: - labeled_channels.append( - np.zeros_like(label_contrasted) - ) - else: - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels.append(labeled_channel) - if ( - np.count_nonzero(labeled_channel) > 0 - and viewer is not None - ): - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) - - cat_labels = np.zeros_like(nuclei_labels) - for labels in np.unique(labeled_channels): - if labels == 0: - continue - cat_labels += np.where(labels != 0, labels, 0) - return cat_labels - - -if __name__ == "__main__": - from pathlib import Path - - import napari - from tifffile import imread - - image_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - ) - # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - nuclei_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" - ) - extra_channels_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" - ) - extra_channels = [ - imread(str(path)) - for path in extra_channels_path.glob( - "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" - ) - ] - labels = imread(str(image_path)) - viewer = napari.Viewer() - - shift = 0 - viewer.add_image( - imread(str(nuclei_path))[ - shift : 32 + shift, shift : 32 + shift, shift : 32 + shift - ], - name="nuclei", - ) - viewer.add_labels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - [ - viewer.add_image( - channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - for channel in extra_channels - ] - - labeled_channels = extract_labels_from_channels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], - [ - c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - for c in extra_channels - ], - radius=4, - viewer=viewer, - ) - - viewer.add_labels(labeled_channels) - # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] - # expanded = expand_labels(labels, 4) - # viewer.add_labels(expanded) - napari.run() diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py deleted file mode 100644 index 145d4e45..00000000 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ /dev/null @@ -1,8 +0,0 @@ -import napari -from tifffile import imread - -y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") - -with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(y, contrast_limits=[0, 2000], multiscale=False) diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py deleted file mode 100644 index 8e87f85c..00000000 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ /dev/null @@ -1,29 +0,0 @@ -import napari -from tifffile import imread - -# Visual -x = imread( - "/Users/maximevidal/Documents/trailmap/data/no-edge-validation/visual-original/volumes/images.tif" -) -y_semantic = imread( - "/Users/maximevidal/Documents/trailmap/data/testing/seg-visual1-single/image.tif" -) -y_instance = imread( - "/Users/maximevidal/Documents/trailmap/data/instance-testing/test-visual-5.tiff" -) -y_true = imread( - "/Users/maximevidal/Documents/3drawdata/visual/labels/labels.tif" -) - -# SM -# x = imread("/Users/maximevidal/Documents/trailmap/data/no-edge-validation/validation-original/volumes/c5images.tif") -# y = imread("/Users/maximevidal/Documents/trailmap/data/instance-testing/test1.tiff") -# y_true = imread("/Users/maximevidal/Documents/3drawdata/somatomotor/labels/c5labels.tif") - -with napari.gui_qt(): - viewer = napari.view_image( - x, colormap="inferno", contrast_limits=[200, 1000] - ) - viewer.add_image(y_semantic, name="semantic_predictions", opacity=0.5) - viewer.add_labels(y_instance, name="instance_predictions", seed=0.6) - viewer.add_labels(y_true, name="truth", seed=0.6) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py deleted file mode 100644 index 6cdb9c43..00000000 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ /dev/null @@ -1,234 +0,0 @@ -import collections -import os - -import torch - -from napari_cellseg3d.code_models.models import get_net -from napari_cellseg3d.code_models.models.unet.model import UNet3D - -# not sure this actually works when put here - - -def weight_translate(k, w): - k = key_translate(k) - if k.endswith(".weight"): - if w.dim() == 2: - w = w.t() - elif w.dim() == 1: - pass - elif w.dim() == 4: - w = w.permute(3, 2, 0, 1) - else: - assert w.dim() == 5 - w = w.permute(4, 3, 0, 1, 2) - return w - - -def key_translate(k): - k = ( - k.replace( - "conv3d/kernel:0", - "encoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization/gamma:0", - "encoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization/beta:0", - "encoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_1/kernel:0", - "encoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_1/gamma:0", - "encoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_1/beta:0", - "encoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_2/kernel:0", - "encoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_2/gamma:0", - "encoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_2/beta:0", - "encoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_3/kernel:0", - "encoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_3/gamma:0", - "encoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_3/beta:0", - "encoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_4/kernel:0", - "encoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_4/gamma:0", - "encoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_4/beta:0", - "encoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_5/kernel:0", - "encoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_5/gamma:0", - "encoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_5/beta:0", - "encoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_6/kernel:0", - "encoders.3.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_6/gamma:0", - "encoders.3.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_6/beta:0", - "encoders.3.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_7/kernel:0", - "encoders.3.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_7/gamma:0", - "encoders.3.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_7/beta:0", - "encoders.3.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_8/kernel:0", - "decoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_8/gamma:0", - "decoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_8/beta:0", - "decoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_9/kernel:0", - "decoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_9/gamma:0", - "decoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_9/beta:0", - "decoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_10/kernel:0", - "decoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_10/gamma:0", - "decoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_10/beta:0", - "decoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_11/kernel:0", - "decoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_11/gamma:0", - "decoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_11/beta:0", - "decoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_12/kernel:0", - "decoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_12/gamma:0", - "decoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_12/beta:0", - "decoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_13/kernel:0", - "decoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_13/gamma:0", - "decoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_13/beta:0", - "decoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace("conv3d_14/kernel:0", "final_conv.weight") - .replace("conv3d_14/bias:0", "final_conv.bias") - ) - return k - - -model = get_net() -base_path = os.path.abspath(__file__ + "/..") -weights_path = base_path + "/data/model-weights/trailmap_model.hdf5" -model.load_weights(weights_path) - -for i, l in enumerate(model.layers): - print(i, l) - print( - "L{}: {}".format( - i, ", ".join(str(w.shape) for w in model.layers[i].weights) - ) - ) - -weights_pt = collections.OrderedDict( - [(w.name, torch.from_numpy(w.numpy())) for w in model.trainable_variables] -) -torch.save(weights_pt, base_path + "/data/model-weights/trailmaptorch.pt") -torch_weights = torch.load(base_path + "/data/model-weights/trailmaptorch.pt") -param_dict = { - key_translate(k): weight_translate(k, v) for k, v in torch_weights.items() -} - -trailmap_model = UNet3D(1, 1) -torchparam = trailmap_model.state_dict() -for k, v in torchparam.items(): - print("{:20s} {}".format(k, v.shape)) - -trailmap_model.load_state_dict(param_dict, strict=False) -torch.save( - trailmap_model.state_dict(), - base_path + "/data/model-weights/trailmaptorchpretrained.pt", -) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index e5b189ef..6a73eba0 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1207,7 +1207,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, possible_paths: list = (), - filetype: str = "Image file (*.tif *.tiff)", + file_extension: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1216,14 +1216,14 @@ def open_file_dialog( possible_paths (str): Paths that may have been chosen before, can be a string or an array of strings containing the paths load_as_folder (bool): Whether to open a folder or a single file. If True, will allow opening folder as a single file (2D stack interpreted as 3D) - filetype (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` + file_extension (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` """ default_path = utils.parse_default_path(possible_paths) return QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype + widget, "Choose file", default_path, file_extension ) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 90a64cfb..663872c4 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -520,7 +520,7 @@ def parse_default_path(possible_paths): # ] print(default_paths) if len(default_paths) == 0: - return str(Path.home()) + return str(Path().home()) default_path = max(default_paths, key=len) return str(default_path) diff --git a/tox.ini b/tox.ini index ba3e8805..0605fc8c 100644 --- a/tox.ini +++ b/tox.ini @@ -38,5 +38,7 @@ deps = qtpy git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] - +; opencv-python +extras = crf +usedevelop = true commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml From f854bf3a0b3883cb049784fb215cd9b3ee06ba6a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:41:07 +0200 Subject: [PATCH 431/577] Run all hooks --- .../_tests/test_plugin_inference.py | 5 ++++- .../code_models/models/model_TRAILMAP.py | 15 +++++--------- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 20 +++++++++++-------- napari_cellseg3d/code_plugins/plugin_base.py | 15 ++++++-------- .../code_plugins/plugin_helper.py | 4 +++- .../code_plugins/plugin_utilities.py | 5 ++++- napari_cellseg3d/dev_scripts/thread_test.py | 6 ++++-- pyproject.toml | 2 +- 9 files changed, 40 insertions(+), 34 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1ae83102..1e486c14 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,9 +34,12 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() + widget.model_choice.setCurrentIndex(-1) + assert widget.window_infer_box.isChecked() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + widget.model_choice.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 8a108e37..e6bbad55 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -39,13 +39,12 @@ def forward(self, x): up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 # print(up8.shape) - out = self.out(up8) + return self.out(up8) # print("out:") # print(out.shape) - return out def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -56,10 +55,9 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.ReLU(), nn.MaxPool3d(2), ) - return encode def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -69,10 +67,9 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.BatchNorm3d(out_ch), nn.ReLU(), ) - return encode def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -85,13 +82,11 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) ), ) - return decode def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) - return out class TRAILMAP_(TRAILMAP): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 3416acb1..2900b89c 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -204,4 +204,4 @@ def __init__(self, device, in_channels, out_channels, dropout=0.65): def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) \ No newline at end of file + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c67ea523..245e6f02 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,28 +199,31 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files + class ONNXModelWrapper(torch.nn.Module): """Class to replace torch model by ONNX Runtime session""" + def __init__(self, file_location): super().__init__() try: - import onnx import onnxruntime as ort except ImportError as e: logger.error("ONNX is not installed but ONNX model was loaded") logger.error(e) msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" logger.error(msg) - raise ImportError(msg) + raise ImportError(msg) from e self.ort_session = ort.InferenceSession( file_location, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) def forward(self, modeL_input): """Wraps ONNX output in a torch tensor""" - outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + outputs = self.ort_session.run( + None, {"input": modeL_input.cpu().numpy()} + ) return torch.tensor(outputs[0]) def eval(self): @@ -231,6 +234,7 @@ def to(self, device): """Dummy function to replace model.to(device)""" pass + @dataclass class InferenceResult: """Class to record results of a segmentation job""" @@ -858,7 +862,7 @@ def inference(self): elif Path(weights_config.path).suffix == ".onnx": self.log("Instantiating ONNX model...") model = ONNXModelWrapper(weights_config.path) - else: # assume is .pth + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], @@ -1606,8 +1610,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1624,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 26da7a42..cfa3f0d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -227,17 +227,16 @@ def _show_filetype_choice(self): def _show_file_dialog(self): """Open file dialog and process path depending on single file/folder loading behaviour""" if self.load_as_stack_choice.isChecked(): - folder = ui.open_folder_dialog( + choice = ui.open_folder_dialog( self, self._default_path, filetype=f"Image file (*{self.filetype_choice.currentText()})", ) - return folder else: f_name = ui.open_file_dialog(self, self._default_path) - f_name = str(f_name[0]) - self.filetype = str(Path(f_name).suffix) - return f_name + choice = str(f_name[0]) + self.filetype = str(Path(choice).suffix) + return choice def _show_dialog_images(self): """Show file dialog and set image path""" @@ -291,16 +290,14 @@ def _make_close_button(self): return btn def _make_prev_button(self): - btn = ui.Button( + return ui.Button( "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) ) - return btn def _make_next_button(self): - btn = ui.Button( + return ui.Button( "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) ) - return btn def remove_from_viewer(self): """Removes the widget from the napari window. diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index a3fd8c0d..54c34a8f 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,6 +1,8 @@ import pathlib +from typing import TYPE_CHECKING -import napari +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import QSize diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 868dd279..6e1a606a 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -1,4 +1,7 @@ -import napari +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import qInstallMessageHandler diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 20668125..a48f6db0 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -1,8 +1,8 @@ import time import napari -import numpy as np from napari.qt.threading import thread_worker +from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -13,6 +13,8 @@ QWidget, ) +rand_gen = Generator(PCG64(12345)) + @thread_worker def two_way_communication_with_args(start, end): @@ -129,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) + viewer = napari.view_image(rand_gen.random(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/pyproject.toml b/pyproject.toml index 2783761e..f71ddb23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ select = [ ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) # and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004"] +ignore = ["E501", "E741", "G004", "A003"] exclude = [ ".bzr", ".direnv", From c55dfac9d9bdb366400efa221cb19b9ca459a425 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 11:27:58 +0200 Subject: [PATCH 432/577] Fix inference testing --- .../_tests/test_plugin_inference.py | 13 +++++++----- .../code_models/models/model_test.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1e486c14..779f5094 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,12 +34,15 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentIndex(-1) + widget.model_choice.setCurrentText("WNet") + widget._restrict_window_size_for_model() assert widget.window_infer_box.isChecked() + assert widget.window_size_choice.currentText() == "64" - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(-1) + test_model_name = "test" + MODEL_LIST[test_model_name] = TestModel + widget.model_choice.addItem(test_model_name) + widget.model_choice.setCurrentText(test_model_name) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None @@ -59,6 +62,6 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) - assert res.result.shape == (6, 6, 6) + assert res.result.shape == (8, 8, 8) widget.on_yield(res) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1cb52f06..28f3a05b 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -20,13 +20,13 @@ def forward(self, x): # return val_inputs -# if __name__ == "__main__": -# -# model = TestModel() -# model.train() -# model.zero_grad() -# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR -# torch.save( -# model.state_dict(), -# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" -# ) +if __name__ == "__main__": + model = TestModel() + model.train() + model.zero_grad() + from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR + + torch.save( + model.state_dict(), + PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", + ) From dba9d8ece668e281657df3616031cc76b1eb6c12 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 13:45:50 +0200 Subject: [PATCH 433/577] Changed anisotropy calculation --- napari_cellseg3d/_tests/test_interface.py | 8 +++++++- napari_cellseg3d/interface.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index be811721..08e0e675 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.interface import Log +from napari_cellseg3d.interface import AnisotropyWidgets, Log def test_log(qtbot): @@ -12,3 +12,9 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) + + +def test_zoom_factor(): + resolution = [10.0, 10.0, 5.0] + zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) + assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 6a73eba0..57d78795 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -734,8 +734,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = min(aniso_res) - return [base / res for res in aniso_res] + base = max(aniso_res) + return [res / base for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" From 6c1b33b065f858f7774b9eb3dd0f14b92b0c9103 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 11:37:13 +0200 Subject: [PATCH 434/577] Finish rebase + bump version --- napari_cellseg3d/__init__.py | 2 +- .../code_models/instance_segmentation.py | 105 +++++++----------- .../code_plugins/plugin_helper.py | 2 +- .../dev_scripts/evaluate_labels.py | 81 +++++++++++++- pyproject.toml | 3 +- setup.cfg | 2 +- 6 files changed, 119 insertions(+), 76 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 6e2681e8..be8123e4 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc1" +__version__ = "0.0.3rc1" diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 93de0768..f5066ebe 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -9,9 +9,6 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui @@ -110,42 +107,6 @@ def run_method_on_channels(self, image): return result.squeeze() -class InstanceMethod: - def __init__( - self, - name: str, - function: callable, - num_sliders: int, - num_counters: int, - ): - self.name = name - self.function = function - self.counters: List[ui.DoubleIncrementCounter] = [] - self.sliders: List[ui.Slider] = [] - if num_sliders > 0: - for i in range(num_sliders): - widget = f"slider_{i}" - setattr( - self, - widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label=""), - ) - self.sliders.append(getattr(self, widget)) - - if num_counters > 0: - for i in range(num_counters): - widget = f"counter_{i}" - setattr( - self, - widget, - ui.DoubleIncrementCounter(label=""), - ) - self.counters.append(getattr(self, widget)) - - def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") - - @dataclass class ImageStats: volume: List[float] @@ -186,7 +147,7 @@ def voronoi_otsu( volume: np.ndarray, spot_sigma: float, outline_sigma: float, - remove_small_size: float, + # remove_small_size: float, ): """ Voronoi-Otsu labeling from pyclesperanto. @@ -202,12 +163,13 @@ def voronoi_otsu( Instance segmentation labels from Voronoi-Otsu method """ - semantic = np.squeeze(volume) + # remove_small_size (float): remove all objects smaller than the specified size in pixels + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) instance = cle.voronoi_otsu_labeling( - semantic, spot_sigma=spot_sigma, outline_sigma=outline_sigma + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) # instance = remove_small_objects(instance, remove_small_size) return np.array(instance) @@ -225,8 +187,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) - """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -445,13 +405,16 @@ def sphericity(region): ) -class Watershed(InstanceMethod, metaclass=Singleton): - def __init__(self): +class Watershed(InstanceMethod): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + + def __init__(self, widget_parent=None): super().__init__( - name="Watershed", + name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -488,13 +451,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): - def __init__(self): +class ConnectedComponents(InstanceMethod): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + + def __init__(self, widget_parent=None): super().__init__( - name="Connected Components", + name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -516,33 +482,37 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): - def __init__(self): +class VoronoiOtsu(InstanceMethod): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" + + def __init__(self, widget_parent=None): super().__init__( - name="Voronoi-Otsu", + name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, - num_counters=3, + num_counters=2, + widget_parent=widget_parent, ) - self.counters[0].label.setText("Spot sigma") + self.counters[0].label.setText("Spot sigma") # closeness self.counters[ 0 ].tooltips = "Determines how close detected objects can be" self.counters[0].setMaximum(100) self.counters[0].setValue(2) - self.counters[1].label.setText("Outline sigma") + self.counters[1].label.setText("Outline sigma") # smoothness self.counters[ 1 ].tooltips = "Determines the smoothness of the segmentation" self.counters[1].setMaximum(100) self.counters[1].setValue(2) - self.counters[2].label.setText("Small object removal") - self.counters[2].tooltips = ( - "Volume/size threshold for small object removal." - "\nAll objects with a volume/size below this value will be removed." - ) + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) def run_method(self, image): ################ @@ -557,7 +527,7 @@ def run_method(self, image): image, self.counters[0].value(), self.counters[1].value(), - self.counters[2].value(), + # self.counters[2].value(), ) @@ -575,7 +545,6 @@ def __init__(self, parent=None): """ super().__init__(parent) - self.method_choice = ui.DropdownMenu( list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) ) @@ -588,7 +557,6 @@ def __init__(self, parent=None): self._build() def _build(self): - group = ui.GroupedWidget("Instance segmentation") group.layout.addWidget(self.method_choice) @@ -620,6 +588,9 @@ def _set_visibility(self): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) def run_method(self, volume): """ @@ -636,7 +607,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, - VoronoiOtsu().name: VoronoiOtsu, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 54c34a8f..552f70ea 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -39,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.3rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 00bce5ec..64fbaf5e 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,5 @@ import napari import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm @@ -128,9 +126,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: @@ -190,6 +186,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file diff --git a/pyproject.toml b/pyproject.toml index f71ddb23..e39a7522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "napari_cellseg3d" -version = "0.0.2rc6" +version = "0.0.3rc1" authors = [ {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, @@ -102,6 +102,7 @@ dev = [ "black", "ruff", "pre-commit", + "tuna", ] docs = [ "sphinx", diff --git a/setup.cfg b/setup.cfg index f3294b60..8ee82f96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc6 +version = 0.0.3rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 556f132a52adf9b529a55b822049a59a9483e21c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 435/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- .github/workflows/test_and_deploy.yml | 4 -- .../_tests/test_plugin_inference.py | 1 + napari_cellseg3d/code_models/workers.py | 19 ++---- .../code_plugins/plugin_convert.py | 29 ++++---- .../code_plugins/plugin_model_inference.py | 68 +++++++++++++++++-- napari_cellseg3d/config.py | 28 ++++---- napari_cellseg3d/interface.py | 59 ++++++++-------- requirements.txt | 3 +- 8 files changed, 129 insertions(+), 82 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index e9a66ae2..bb3662e8 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,16 +8,12 @@ on: branches: - main - npe2 - - cy/voronoi-otsu - - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 - - cy/voronoi-otsu - - cy/wnet workflow_dispatch: jobs: diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 779f5094..68ab2067 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -14,6 +14,7 @@ from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 245e6f02..c588ce14 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -53,11 +53,9 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.instance_segmentation import ( - ImageStats, - volume_stats, -) +from napari_cellseg3d import utils +from napari_cellseg3d.code_models.model_instance_seg import ImageStats +from napari_cellseg3d.code_models.model_instance_seg import volume_stats logger = utils.LOGGER @@ -679,15 +677,8 @@ def instance_seg( if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance.method - instance_labels = method.run_method_on_channels(semantic_labels) - self.log(f"DEBUG instance results shape : {instance_labels.shape}") - - filetype = ( - ".tif" - if self.config.filetype == "" - else "_" + self.config.filetype - ) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 4357e51e..b96727d0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -13,6 +13,9 @@ threshold, to_semantic, ) +from napari_cellseg3d.code_models.model_instance_seg import threshold +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder MAX_W = ui.UTILS_MAX_WIDTH @@ -276,20 +279,18 @@ def _start(self): utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) - + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): """ diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 256cffa4..33e15543 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -15,13 +15,13 @@ InstanceWidgets, ) from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.workers import ( - InferenceResult, - InferenceWorker, +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, ) -from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget - -logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -581,7 +581,61 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self._set_worker_config() + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) + + self.weights_config.custom = self.custom_weights_choice.isChecked() + + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + warnings.warn(msg) + else: + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, + ) + else: + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, + ) ##################### ##################### ##################### diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 5c0b34be..c082407e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,14 +6,20 @@ import napari import numpy as np -from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ -from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ -from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ -from napari_cellseg3d.code_models.models.model_VNet import VNet_ -from napari_cellseg3d.code_models.models.model_WNet import WNet_ +from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet +from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR +from napari_cellseg3d.code_models.models import ( + model_TRAILMAP_MS as TRAILMAP_MS, +) +from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -31,7 +37,8 @@ # "test" : DO NOT USE, reserved for testing } -PRETRAINED_WEIGHTS_DIR = str( + +WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") ) @@ -117,11 +124,6 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - @dataclass class PostProcessConfig: """Class to record params for post processing @@ -134,7 +136,7 @@ class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None @dataclass diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 57d78795..61764c84 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -9,30 +9,33 @@ from qtpy import QtCore # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject, Qt, QUrl -from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor -from qtpy.QtWidgets import ( - QCheckBox, - QComboBox, - QDoubleSpinBox, - QFileDialog, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QLineEdit, - QMenu, - QPushButton, - QRadioButton, - QScrollArea, - QSizePolicy, - QSlider, - QSpinBox, - QTextEdit, - QVBoxLayout, - QWidget, -) +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt +# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QUrl +from qtpy.QtGui import QCursor +from qtpy.QtGui import QDesktopServices +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import QCheckBox +from qtpy.QtWidgets import QComboBox +from qtpy.QtWidgets import QDoubleSpinBox +from qtpy.QtWidgets import QFileDialog +from qtpy.QtWidgets import QGridLayout +from qtpy.QtWidgets import QGroupBox +from qtpy.QtWidgets import QHBoxLayout +from qtpy.QtWidgets import QLabel +from qtpy.QtWidgets import QLayout +from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QMenu +from qtpy.QtWidgets import QPushButton +from qtpy.QtWidgets import QRadioButton +from qtpy.QtWidgets import QScrollArea +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QSlider +from qtpy.QtWidgets import QSpinBox +from qtpy.QtWidgets import QTextEdit +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # Local from napari_cellseg3d import utils @@ -525,10 +528,10 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.label.setVisible(visible) + self.text_label.setVisible(visible) def _build_container(self): - if self.label is not None: + if self.text_label is not None: add_widgets( self.container.layout, [ @@ -1082,8 +1085,8 @@ def __init__( self.layout = None - if text_label is not None: - self.label = make_label(name=text_label) + if label is not None: + self.label = make_label(name=label) self.valueChanged.connect(self._update_step) def _update_step(self): diff --git a/requirements.txt b/requirements.txt index 3ca0e56d..e7a321f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,8 +14,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pre-commit -pyclesperanto-prototype>=0.22.0 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From db4babfc2113420733eb583a2e22205d103729f8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 436/577] Disabled small removal in Voronoi-Otsu --- napari_cellseg3d/code_models/instance_segmentation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index f5066ebe..ff99409a 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -159,15 +159,13 @@ def voronoi_otsu( spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation + Returns: Instance segmentation labels from Voronoi-Otsu method """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - # semantic = np.squeeze(volume) - logger.debug( - f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" - ) + semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -491,7 +489,6 @@ def __init__(self, widget_parent=None): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ From 517ae33abecbcb5eb96c65849270c8cb67353916 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 437/577] Added new docs for instance seg --- .../code_models/instance_segmentation.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index ff99409a..5aded9bf 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -38,14 +38,11 @@ def __init__( ): """ Methods for instance segmentation - Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function - widget_parent: parent for the declared widgets - """ self.name = name self.function = function @@ -403,10 +400,10 @@ def sphericity(region): ) -class Watershed(InstanceMethod): +class Watershed(InstanceMethod, metaclass=Singleton): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=WATERSHED, function=binary_watershed, @@ -449,10 +446,10 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod): +class ConnectedComponents(InstanceMethod, metaclass=Singleton): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=CONNECTED_COMP, function=binary_connected, @@ -480,10 +477,10 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod): +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -604,7 +601,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VORONOI_OTSU: VoronoiOtsu, - WATERSHED: Watershed, - CONNECTED_COMP: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, } From 45ce3213b97538816dff4643f36e9aefc41f6089 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 438/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 2 +- docs/res/welcome.rst | 4 +++- napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ece6c6f4..ca8d0931 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 892549a8..12a20630 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -103,6 +103,8 @@ This plugin mainly uses the following libraries and software: * `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase + * A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ @@ -113,7 +115,7 @@ This plugin mainly uses the following libraries and software: .. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype -.. _WNet model: https://arxiv.org/abs/1711.08506 + .. rubric:: References diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 61764c84..ec131ce0 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1091,9 +1091,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From 87810fa3d8a2d735cc215d3fe5e2ed94e10be1ac Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 439/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e7a321f2..8607ae90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From 5beafd969e8d946364dbb8e0ae95526242a1b851 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 440/577] isort --- .../code_models/instance_segmentation.py | 18 +++++++++--------- .../code_plugins/plugin_convert.py | 10 ++-------- .../code_plugins/plugin_model_inference.py | 8 ++++---- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 3 ++- 5 files changed, 22 insertions(+), 28 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 5aded9bf..b7cbcba9 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -1,23 +1,23 @@ -import abc +from __future__ import division +from __future__ import print_function from dataclasses import dataclass from functools import partial from typing import List - import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label, regionprops +from skimage.measure import label +from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread - -from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis - -# from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes +from napari_cellseg3d import interface as ui +from napari_cellseg3d.utils import fill_list_in_between +from napari_cellseg3d.utils import Singleton +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index b96727d0..23cc581d 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,5 +1,4 @@ from pathlib import Path - import napari import numpy as np from qtpy.QtWidgets import QSizePolicy @@ -7,15 +6,10 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.instance_segmentation import ( - InstanceWidgets, - clear_small_objects, - threshold, - to_semantic, -) +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder MAX_W = ui.UTILS_MAX_WIDTH diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 33e15543..5d7c3a23 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -15,13 +15,13 @@ InstanceWidgets, ) from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import ( INSTANCE_SEGMENTATION_METHOD_LIST, ) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index c082407e..d07dd1ab 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,6 +6,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -14,12 +19,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index ec131ce0..09269497 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -11,7 +11,8 @@ # from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject from qtpy.QtCore import Qt -# from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From 334cae015ae47808e4080b0ec1dbef89107fb070 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 441/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 45c3be1c..814cca2e 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 +qt_api=pyqt5 \ No newline at end of file From fa01b7de5c9f8dc2cc08da7fcd54ad0b14478069 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 442/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/instance_segmentation.py | 51 ++++++++----------- .../code_plugins/plugin_model_inference.py | 1 - 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index b7cbcba9..a6bbb7f1 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -16,8 +16,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area @@ -34,7 +34,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None, + widget_parent: QWidget = None ): """ Methods for instance segmentation @@ -43,6 +43,7 @@ def __init__( function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets """ self.name = name self.function = function @@ -54,14 +55,7 @@ def __init__( setattr( self, widget, - ui.Slider( - 0, - 100, - 1, - divide_factor=100, - text_label="", - parent=widget_parent, - ), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -71,9 +65,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter( - text_label="", parent=widget_parent - ), + ui.DoubleIncrementCounter(label="", parent=None), ) self.counters.append(getattr(self, widget)) @@ -400,16 +392,16 @@ def sphericity(region): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, + widget_parent=widget_parent ) self.sliders[0].label.setText("Foreground probability threshold") @@ -446,16 +438,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, + widget_parent=widget_parent ) self.sliders[0].label.setText("Foreground probability threshold") @@ -477,15 +469,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -557,9 +550,8 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) - self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets ? + # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -570,15 +562,14 @@ def _build(self): group.layout.addWidget(counter) self.instance_widgets[name].append(counter) except RuntimeError as e: - logger.debug( - f"Caught runtime error {e}, most likely during testing" - ) + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets: + + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) @@ -601,7 +592,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 5d7c3a23..cdb770fd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -203,7 +203,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) - self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", From 5684d8f5c2de31161638f9acf065c31e51a62859 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 443/577] Fix inference --- .../code_models/instance_segmentation.py | 1 + napari_cellseg3d/code_models/workers.py | 6 +- .../code_plugins/plugin_model_inference.py | 23 +- napari_cellseg3d/config.py | 7 +- notebooks/assess_instance.ipynb | 609 +----------------- 5 files changed, 36 insertions(+), 610 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index a6bbb7f1..9ecbf8b4 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -550,6 +550,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c588ce14..0baa6373 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -617,6 +617,8 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) + if from_layer: + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -677,8 +679,8 @@ def instance_seg( if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index cdb770fd..c943fb52 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -605,9 +605,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -873,11 +874,13 @@ def on_yield(self, result: InferenceResult): np.unique(result.instance_labels.flatten()).size - 1 ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + if result.instance_labels is not None: + labels = result.instance_labels + method_name = self.worker_config.post_process_config.instance.method.name viewer.add_labels(result.instance_labels, name=name) - from napari_cellseg3d.utils import LOGGER as log + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" if result.stats is not None and isinstance( result.stats, list @@ -898,11 +901,11 @@ def on_yield(self, result: InferenceResult): f"Number of instances in channel {i} : {stats.number_objects[0]}" ) - csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index d07dd1ab..15e48f6e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -123,6 +123,11 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None + @dataclass class PostProcessConfig: """Class to record params for post processing @@ -135,7 +140,7 @@ class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() @dataclass diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 3dae22a9..40412282 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,632 +4,47 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, - "tags": [] + "collapsed": true }, "outputs": [], "source": [ - "import napari\n", "import numpy as np\n", - "from pathlib import Path\n", "from tifffile import imread\n", - "\n", - "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", - "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import (\n", - " binary_connected,\n", - " binary_watershed,\n", - " voronoi_otsu,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "pycharm": { - "is_executing": true - }, - "tags": [] - }, - "outputs": [], - "source": [ - "viewer = napari.Viewer()" + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", - "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", - "\n", - "prediction = imread(prediction_path)\n", - "gt_labels = imread(gt_labels_path)\n", - "\n", - "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", - "# prediction_resized = prediction # for trailmap\n", - "gt_labels_resized = resize(gt_labels, zoom)\n", - "\n", - "\n", - "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.8592223181276479" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from napari_cellseg3d.utils import dice_coeff\n", - "\n", - "semantic_gt = to_semantic(gt_labels_resized.copy())\n", - "semantic_pred = to_semantic(prediction_resized.copy())\n", - "\n", - "viewer.add_image(semantic_gt, colormap='bop blue')\n", - "viewer.add_image(semantic_pred, colormap='red')\n", - "\n", - "dice_coeff(\n", - " semantic_gt,\n", - " prediction_resized\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, + "execution_count": null, "outputs": [], - "source": [ - "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", - "\n", - "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n", - "125\n" - ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized, thres_small=2)\n", - "viewer.add_labels(connected, name=\"connected\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:19,775 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", - "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", - "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, + "source": [], "metadata": { "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:19,919 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", - "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", - "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(68,\n", - " 43,\n", - " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", - ")\n", - "viewer.add_labels(watershed)\n", - "eval.evaluate_model_performance(gt_labels_resized, watershed)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:21,076 - build program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", - "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", - "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", - "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", - "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", - "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", - "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", - "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", - "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", - "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", - "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", - "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", - "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", - "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", - "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", - "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", - "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", - "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", - "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", - "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", - "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", - "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", - "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", - "2023-03-31 15:37:32,685 - build program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", - "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", - "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", - "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", - "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" - ] - }, - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", - "\n", - "from skimage.morphology import remove_small_objects\n", - "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "is_executing": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(voronoi, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(gt_labels_resized, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:36,854 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", - "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", - "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", - "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(109,\n", - " 2,\n", - " 13,\n", - " 8,\n", - " 0.8285521200005869,\n", - " 0.8809251900364068,\n", - " 0.9838709677419355,\n", - " 0.782258064516129,\n", - " 1.0)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, "pycharm": { - "is_executing": true + "name": "#%%\n" } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" - ] - } - ], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file From 16766b85936c2da71da3423841f43c0254164db4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 444/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/instance_segmentation.py | 3 +- napari_cellseg3d/code_plugins/plugin_crop.py | 7 +- .../dev_scripts/artefact_labeling.py | 143 +++---- .../dev_scripts/correct_labels.py | 125 ++---- .../dev_scripts/evaluate_labels.py | 405 ++++-------------- notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 6 files changed, 584 insertions(+), 500 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 9ecbf8b4..dc637159 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -38,12 +38,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -148,7 +150,6 @@ def voronoi_otsu( spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 74691e1f..5b09ad3f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -194,12 +194,7 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable( - layout, - self, - max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], - min_wh=[200, 200], - ) + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 93746eb6..875ca9b6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,17 +1,15 @@ -import os # TODO(cyril): remove os -from pathlib import Path - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite +from pathlib import Path import scipy.ndimage as ndimage -from skimage.filters import threshold_otsu -from tifffile import imread, imwrite - -from napari_cellseg3d.code_models.instance_segmentation import binary_watershed - +import os +import napari # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu """ New code by Yves Paychere @@ -46,9 +44,7 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append( - np.array([i, unique[np.argmax(counts)]]) - ) + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -65,7 +61,7 @@ def map_labels(labels, artefacts): def make_labels( - image, + path_image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -77,7 +73,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - image : str + path_image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -97,25 +93,21 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - # image = imread(image) + image = imread(path_image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( - np.max(image_contrasted) - np.min(image_contrasted) - ) + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size( - labels, min_size=threshold_size, is_labeled=True - ) + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -127,26 +119,26 @@ def make_labels( ) -def select_image_by_labels(image, labels, path_image_out, label_values): +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - image : np.array - image. - labels : np.array - labels. + path_image : str + Path to image. + path_labels : str + Path to labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - # image = imread(image) - # labels = imread(labels) + image = imread(path_image) + labels = imread(path_labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smallest cube that contains all the non-zero pixels of a 3d image +# select the smalles cube that contains all the none zero pixel of an 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -164,15 +156,16 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(image, path_image_out): +def crop_image_path(path_image, path_image_out): """Crop image. Parameters ---------- - image : np.array - image + path_image : str + Path to image. path_image_out : str Path of the output image. """ + image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -220,9 +213,7 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile( - image[neurons], threshold_artefact_brightness_percent - ) + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -253,9 +244,7 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile( - sizes, threshold_artefact_size_percent - ) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -290,18 +279,23 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. """ - labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts + if not is_labeled: + # find all the connected components in the artefacts image + labels = ndimage.label(artefacts)[0] + else: + labels = artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - return np.where(np.isin(labels, labels_i), labels, 0) + artefacts = np.where(np.isin(labels, labels_i), labels, 0) + return artefacts def create_artefact_labels( - image, - labels, + image_path, + labels_path, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -310,10 +304,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image : np.array - image for artefact detection. - labels : np.array - label image array with each neurons labelled as a different int value. + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -323,6 +317,9 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ + image = imread(image_path) + labels = imread(labels_path) + artefacts = make_artefact_labels( image, labels, @@ -342,12 +339,11 @@ def visualize_images(paths): Parameters ---------- paths : list - List of images to visualize. + List of paths to images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - image = imread(path) - viewer.add_image(image) + viewer.add_image(imread(path), name=os.path.basename(path)) # wait for the user to close the viewer napari.run() @@ -374,17 +370,13 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [ - f for f in os.listdir(path + "/labels") if f.endswith(".tif") - ] - path_images = [ - f for f in os.listdir(path + "/volumes") if f.endswith(".tif") - ] + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] # sort the list path_labels.sort() path_images.sort() # create the output folder - Path().mkdir(path + "/artefact_neurons", exist_ok=True) + os.makedirs(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) @@ -407,22 +399,23 @@ def create_artefact_labels_from_folder( ) -# if __name__ == "__main__": -# repo_path = Path(__file__).resolve().parents[1] -# print(f"REPO PATH : {repo_path}") -# paths = [ -# "dataset_clean/cropped_visual/train", -# "dataset_clean/cropped_visual/val", -# "dataset_clean/somatomotor", -# "dataset_clean/visual_tif", -# ] -# for data_path in paths: -# path = str(repo_path / data_path) -# print(path) -# create_artefact_labels_from_folder( -# path, -# do_visualize=False, -# threshold_artefact_brightness_percent=20, -# threshold_artefact_size_percent=1, -# contrast_power=20, -# ) +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f413812d..f94327e2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,23 +1,19 @@ -import threading -import time -import warnings -from functools import partial -from pathlib import Path - -import napari import numpy as np +from tifffile import imread +from tifffile import imwrite import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings from napari.qt.threading import thread_worker -from tifffile import imread, imwrite from tqdm import tqdm - -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.instance_segmentation import binary_watershed - +import threading # import sys # sys.path.append(str(Path(__file__) / "../../")) - +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -37,9 +33,7 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm( - range(len(unique_label)), desc="relabeling", ncols=100 - ): + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): i = unique_label[i_label] if i == 0: continue @@ -87,16 +81,13 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact, test=False): +def ask_labels(unique_artefact): global returns returns = [] - if not test: - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] - else: - i_labels_to_add_tmp = [0] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] if i_labels_to_add_tmp == [0]: print("no label added") @@ -139,15 +130,7 @@ def ask_labels(unique_artefact, test=False): print("close the napari window to continue") -def relabel( - image_path, - label_path, - go_fast=False, - check_for_unicity=True, - delay=0.3, - viewer=None, - test=False, -): +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -161,8 +144,6 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 - viewer : napari.Viewer, optional - the napari viewer, by default None """ global returns @@ -177,10 +158,7 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - if not test: - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -200,49 +178,30 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread( - target=partial(ask_labels, test=test), args=(unique_artefact,) - ) + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add), 0, artefact - ) - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - viewer.add_image(image, name="image") + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - if not test: - napari.run() + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add_tmp), artefact, 0 - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) print("these labels will be added") - if test: - viewer.close() - viewer = napari.view_image(image) if viewer is None else viewer - if not test: - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") - if test: - revert = "n" - viewer.close() + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) - if test: - break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -299,16 +258,12 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget( - old_label, new_label, map_labels_existing, delay=0.5 -): +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect( - lambda arg: modify_viewer(old_label, new_label, arg) - ) + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -325,12 +280,8 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] - ) - new_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] - ) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -339,9 +290,7 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget( - old_label, new_label, map_labels_existing, delay=delay - ) + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) napari.run() @@ -358,14 +307,14 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, - str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) ) if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") - image_path = str(im_path / "volumes/c1images.tif") - gt_labels_path = str(im_path / "labels/c1labels.tif") - relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 64fbaf5e..857bcd19 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,20 +1,74 @@ -import napari import numpy as np import pandas as pd from tqdm import tqdm +import napari from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, -): +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): """Evaluate the model performance. Parameters ---------- @@ -24,8 +78,6 @@ def evaluate_model_performance( Label image from the model labelled as mulitple values. do_print : bool If True, print the results. - visualize : bool - If True, visualize the results. Returns ------- neuron_found : float @@ -67,9 +119,7 @@ def evaluate_model_performance( artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean( - [i[3] for i in map_labels_existing] - ) + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -78,9 +128,7 @@ def evaluate_model_performance( if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean( - [i[2] for i in map_fused_neurons] - ) + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -95,42 +143,29 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") - log.info( - "Mean true positive ratio of the model: ", - ) - log.info(mean_true_positive_ratio_model) - log.info( + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, ) - log.info(mean_ratio_pixel_found) - log.info( + print( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, ) - log.info(mean_ratio_pixel_found_fused) - log.info( + print( "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, + ) + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) - if visualize: - viewer = napari.Viewer(ndisplay=3) + viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( @@ -144,21 +179,15 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, - unique_labels, - 0, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where( - np.isin(labels, neurones_not_found_labels), labels, 0 - ) + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), - model_labels, - 0, + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -186,81 +215,6 @@ def evaluate_model_performance( ) -def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > threshold_correct: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > threshold_correct: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels - - def save_as_csv(results, path): """ Save the results as a csv file @@ -272,7 +226,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - log.debug(np.array(results).shape) + print(np.array(results).shape) df = pd.DataFrame( [results], columns=[ @@ -290,193 +244,6 @@ def save_as_csv(results, path): df.to_csv(path, index=False) -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From 984b212b83ef4d9faac15045ec2beb9b71167596 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 445/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From 7bba6277d7d6f0776c66e33b3cbb99cc3607558b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 446/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- notebooks/assess_instance.ipynb | 553 ++++++++---------- requirements.txt | 2 +- setup.cfg | 2 +- 6 files changed, 565 insertions(+), 352 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { diff --git a/requirements.txt b/requirements.txt index 8607ae90..92aae176 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ matplotlib>=3.4.1 ruff tifffile>=2022.2.9 torch>=1.11 -monai[nibabel,einops,tifffile]>=1.0.1 +monai[nibabel,einops]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 8ee82f96..9045eec4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops,tifffile]>=1.0.1 + monai[nibabel,einops]>=1.0.1 itk tqdm nibabel From 1cd7f0e320071f6b99280edbf9533e451d564965 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 447/577] black --- .../code_models/instance_segmentation.py | 21 ++++++++---- napari_cellseg3d/code_models/workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index dc637159..5881966d 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -34,7 +34,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -57,7 +57,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -396,13 +403,13 @@ def sphericity(region): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -442,13 +449,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -479,7 +486,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 0baa6373..dedade61 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -618,7 +618,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index c943fb52..436751b2 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -607,7 +607,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -876,7 +878,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) viewer.add_labels(result.instance_labels, name=name) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 15e48f6e..28f7d314 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -123,11 +123,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: """Class to record params for post processing diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From ab8e078a04d644720787919d50c73d9e9858eb24 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 448/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From bb43b543e2712db3d5562fdfc3288898272046e9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 449/577] Added pre-commit hooks --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 92aae176..a7dd1570 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From 0f44f38d649683dd244eeafb3387d7d12800e6cd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 450/577] Enfore pre-commit style --- .gitignore | 5 -- .../_tests/test_plugin_inference.py | 2 - .../code_models/instance_segmentation.py | 8 +- .../code_plugins/plugin_model_inference.py | 3 - .../code_plugins/plugin_utilities.py | 4 +- napari_cellseg3d/config.py | 3 - .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- notebooks/assess_instance.ipynb | 79 +++++++++++++------ 10 files changed, 76 insertions(+), 53 deletions(-) diff --git a/.gitignore b/.gitignore index 7460d861..4eb18db2 100644 --- a/.gitignore +++ b/.gitignore @@ -107,8 +107,3 @@ notebooks/full_plot.html notebooks/instance_test.ipynb *.prof -#include test data -!napari_cellseg3d/_tests/res/test.tif -!napari_cellseg3d/_tests/res/test.png -!napari_cellseg3d/_tests/res/test_labels.tif -cov.syspath.txt diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 68ab2067..f78cbbf4 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -13,8 +13,6 @@ ) from napari_cellseg3d.config import MODEL_LIST - - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 5881966d..dcf9397e 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function from dataclasses import dataclass from functools import partial from typing import List @@ -11,6 +9,7 @@ from skimage.morphology import remove_small_objects from skimage.segmentation import watershed from tifffile import imread + # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -570,14 +569,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 436751b2..cb406931 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -15,9 +15,6 @@ InstanceWidgets, ) from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 6e1a606a..cbc2103d 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -5,7 +5,9 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget # local import napari_cellseg3d.interface as ui diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 28f7d314..71997786 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,10 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, From e722ae2334b343f8c962b7a672e305312fa2c9fc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 451/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4eb18db2..29ad584b 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,4 @@ notebooks/full_plot.html notebooks/instance_test.ipynb *.prof + From 2204aa96ac1fce7e4cd8db3de135837723d93511 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 452/577] Version bump --- napari_cellseg3d/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index be8123e4..736c7f72 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1,2 @@ -__version__ = "0.0.3rc1" +__version__ = "0.0.2rc6" + From 27e1ee40bf4307020e38b18eb5e320490984a665 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 453/577] Updated project files --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e39a7522..38a03414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ dev = [ "black", "ruff", "pre-commit", - "tuna", + ] docs = [ "sphinx", @@ -127,3 +127,4 @@ onnx-gpu = [ "onnx", "onnxruntime-gpu" ] + From 576cc192803a0785bf4be1247fde00fb668b57ca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 454/577] Fixed missing parent error --- napari_cellseg3d/code_models/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index dcf9397e..261c7f0a 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -479,7 +479,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From 29081566f62751ac81cb78c53ae2de2feb5d023e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 455/577] Fixed wrong value in instance sliders --- .../code_models/instance_segmentation.py | 13 +++++++++---- .../code_plugins/plugin_model_inference.py | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 261c7f0a..453f50cf 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -162,6 +162,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -559,7 +562,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -569,8 +572,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() @@ -595,7 +600,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method_on_channels(volume) + return method.run_method(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index cb406931..27419f80 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -195,6 +195,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## From d926a2b15482c35c8c77767f4f23b64626e8ce0f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 456/577] Removing dask-image --- .gitignore | 1 + napari_cellseg3d/utils.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 29ad584b..d6ce41a3 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,4 @@ notebooks/instance_test.ipynb *.prof +*.prof diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 663872c4..31ea1a65 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -5,7 +5,8 @@ import napari import numpy as np -from monai.transforms import Zoom + +# from dask import delayed from skimage import io from skimage.filters import gaussian from tifffile import imread, imwrite @@ -579,6 +580,10 @@ def load_images( "Loading as folder not implemented yet. Use napari to load as folder" ) # images_original = dask_imread(filename_pattern_original) + else: + images_original = tfl_imread( + filename_pattern_original + ) # tifffile imread return imread(filename_pattern_original) # tifffile imread From a1926aed9c79e87f019359e947de1c78ad844409 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 457/577] Fixed erroneous dtype conversion --- .../code_models/instance_segmentation.py | 3 ++- .../code_plugins/plugin_convert.py | 25 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 453f50cf..cefeea76 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -161,7 +161,7 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) @@ -512,6 +512,7 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + ################ # For debugging # import napari diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 23cc581d..edca598c 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -95,19 +95,18 @@ def _start(self): f"isotropic_{layer.name}", ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): From 8c5bd105eb78b20cbb6305a21ba771737cedc906 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 458/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 60c25ccc..dd8e4955 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,8 @@ +from pathlib import Path +from tifffile import imread import numpy as np -from numpy.random import PCG64, Generator +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, @@ -13,9 +15,10 @@ def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - image = rand_gen.random((10, 10, 10)).astype(np.uint8) - image_layer = view.add_image(image, name="image") - label_layer = view.add_labels(image.astype(np.uint8), name="labels") + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(im_path) + view.add_image(image) + view.add_labels(image.astype(np.uint8)) view.window.add_dock_widget(widget) view.dims.ndisplay = 3 @@ -24,11 +27,4 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) - if utils_name == "Convert to instance labels": - # to avoid issues with Voronoi-Otsu missing runtime - menu = widget.utils_widgets[i].instance_widgets.method_choice - menu.setCurrentIndex(menu.currentIndex() + 1) - - assert len(image_layer.data.shape) == 3 - assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() From c11663213cf97e3b7427dcc9fab0fce7c15b6a01 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:38:13 +0200 Subject: [PATCH 459/577] Temporary test action patch --- .github/workflows/test_and_deploy.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index bb3662e8..5401bfd0 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,12 +8,14 @@ on: branches: - main - npe2 + - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 + - cy/voronoi-otsu workflow_dispatch: jobs: From 49aa95b36e8480fa74582823e3ef27a68d913056 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:50:16 +0200 Subject: [PATCH 460/577] Update plugin_convert.py --- .../code_plugins/plugin_convert.py | 73 ++++++++++++++++++- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index edca598c..5f66835b 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -18,6 +18,71 @@ logger = utils.LOGGER +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False, parents=True) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + logger.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + logger.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + logger.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + logger.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + warnings.warn( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -73,7 +138,7 @@ def _build(self): ) def _start(self): - utils.mkdir_from_str(self.results_path) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -172,7 +237,7 @@ def _build(self): return container def _start(self): - utils.mkdir_from_str(self.results_path) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -339,7 +404,7 @@ def _build(self): ) def _start(self): - utils.mkdir_from_str(self.results_path) + self.results_path.mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -434,7 +499,7 @@ def _build(self): return container def _start(self): - utils.mkdir_from_str(self.results_path) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.binarize_counter.value() if self.layer_choice: From 54f5910f79f3ab16df2a3be17c89cca29987d472 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 461/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 0605fc8c..8aaea25e 100644 --- a/tox.ini +++ b/tox.ini @@ -36,8 +36,7 @@ deps = magicgui pytest-qt qtpy - git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf -; pyopencl[pocl] + pocl ; opencv-python extras = crf usedevelop = true From 242bd5780186404f7685b193c7172159bf3d3529 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 462/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 8aaea25e..22b09bf5 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python extras = crf usedevelop = true From 7187bfa720e768138f9c0b691f5b7f42c88ffc2d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 463/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 22b09bf5..82fa219b 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python extras = crf usedevelop = true From 7b79d5e5dc63b761fad72dddf93b721da112e6d0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 464/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index dd8e4955..29d96c9b 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -27,4 +27,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index 82fa219b..4b04a5bc 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python extras = crf usedevelop = true From 6e81c3fee7dbaf429fc73de7cf47dc95537f5ac1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 465/577] Relabeling tests --- .gitignore | 6 +- .../_tests/test_labels_correction.py | 11 +-- .../dev_scripts/artefact_labeling.py | 93 +++++++++---------- .../dev_scripts/correct_labels.py | 75 ++++++++++----- 4 files changed, 105 insertions(+), 80 deletions(-) diff --git a/.gitignore b/.gitignore index d6ce41a3..df67a187 100644 --- a/.gitignore +++ b/.gitignore @@ -107,5 +107,7 @@ notebooks/full_plot.html notebooks/instance_test.ipynb *.prof - -*.prof +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index b4f13238..9d4e7801 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl @@ -37,16 +36,16 @@ def test_correct_labels(): ) -def test_relabel(): +def test_relabel(make_napari_viewer): + viewer = make_napari_viewer() cl.relabel( str(image_path), str(labels_path), go_fast=True, + viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance( - labels, labels, print_details=True, visualize=False - ) + el.evaluate_model_performance(labels, labels, print_details=True) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 9a344545..bf724a46 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,7 +1,5 @@ import numpy as np -from tifffile import imread -from tifffile import imwrite -from pathlib import Path +from tifffile import imwrite, imread import scipy.ndimage as ndimage import os import napari @@ -64,7 +62,7 @@ def map_labels(labels, artefacts): def make_labels( - path_image, + image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -76,7 +74,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - path_image : str + image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -96,7 +94,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(path_image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -126,28 +124,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - path_image : str - Path to image. - path_labels : str - Path to labels. + image : np.array + image. + labels : np.array + labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - image = imread(path_image) - labels = imread(path_labels) + # image = imread(image) + # labels = imread(labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smalles cube that contains all the none zero pixel of an 3d image +# select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -165,16 +161,15 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(path_image, path_image_out): +def crop_image_path(image, path_image_out): """Crop image. Parameters ---------- - path_image : str - Path to image. + image : np.array + image path_image_out : str Path of the output image. """ - image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -307,8 +302,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image_path, - labels_path, + image, + labels, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -317,10 +312,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image_path : str - Path to image file. - labels_path : str - Path to label image file with each neurons labelled as a different value. + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -330,9 +325,6 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ - image = imread(image_path) - labels = imread(labels_path) - artefacts = make_artefact_labels( image, labels, @@ -352,11 +344,12 @@ def visualize_images(paths): Parameters ---------- paths : list - List of paths to images to visualize. + List of images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - viewer.add_image(imread(path), name=os.path.basename(path)) + image = imread(path) + viewer.add_image(image) # wait for the user to close the viewer napari.run() @@ -416,22 +409,22 @@ def create_artefact_labels_from_folder( ) -if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] - print(f"REPO PATH : {repo_path}") - paths = [ - "dataset_clean/cropped_visual/train", - "dataset_clean/cropped_visual/val", - "dataset_clean/somatomotor", - "dataset_clean/visual_tif", - ] - for data_path in paths: - path = str(repo_path / data_path) - print(path) - create_artefact_labels_from_folder( - path, - do_visualize=False, - threshold_artefact_brightness_percent=20, - threshold_artefact_size_percent=1, - contrast_power=20, - ) +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index cd09754e..50f2e47a 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -4,6 +4,7 @@ import scipy.ndimage as ndimage import napari from pathlib import Path +from functools import partial import time import warnings from napari.qt.threading import thread_worker @@ -85,13 +86,16 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact): +def ask_labels(unique_artefact, test=False): global returns returns = [] - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] if i_labels_to_add_tmp == [0]: print("no label added") @@ -135,7 +139,13 @@ def ask_labels(unique_artefact): def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -150,6 +160,8 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None """ global returns @@ -164,9 +176,10 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -186,15 +199,22 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) t.start() artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - viewer = napari.view_image(image) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - napari.run() + if not test: + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones @@ -205,15 +225,26 @@ def relabel( np.isin(artefact, i_labels_to_add_tmp), artefact, 0 ) print("these labels will be added") - viewer = napari.view_image(image) - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") + if test: + viewer.close() + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) + if test: + break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -334,9 +365,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") - image_path = str(im_path / "image.tif") - gt_labels_path = str(im_path / "labels.tif") - - relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) From 75246f4a63478baf18faa378ce3ef4d54f66c2fd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 466/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f9fe2853..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - id: check-yaml - - id: check-added-large-files - - id: check-toml -# - repo: https://github.com/pycqa/isort -# rev: 5.12.0 -# hooks: -# - id: isort -# args: ["--profile", "black", --line-length=79] + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From 8b4f5ba13ec5bca94b2dda2d71a1040793ef695b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 467/577] Run full suite of pre-commit hooks --- README.md | 2 +- napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/instance_segmentation.py | 3 +-- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- 9 files changed, 29 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ca8d0931..ece6c6f4 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 29d96c9b..f0470990 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import ( diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index cefeea76..db0ffa79 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -15,8 +15,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -512,7 +512,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log From d9dc7759902ea03fa8515da10277fa31d5b356f6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:08:38 +0200 Subject: [PATCH 468/577] Enforce style --- napari_cellseg3d/__init__.py | 1 - napari_cellseg3d/_tests/test_plugin_inference.py | 1 + napari_cellseg3d/_tests/test_plugin_utils.py | 5 +---- .../code_models/instance_segmentation.py | 8 +++++--- napari_cellseg3d/code_models/models/unet/model.py | 1 + napari_cellseg3d/code_plugins/plugin_convert.py | 2 ++ napari_cellseg3d/code_plugins/plugin_review.py | 1 + napari_cellseg3d/code_plugins/plugin_utilities.py | 13 +++++-------- napari_cellseg3d/config.py | 1 - napari_cellseg3d/interface.py | 5 +---- pyproject.toml | 1 - 11 files changed, 17 insertions(+), 22 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 736c7f72..11e8de0e 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1,2 +1 @@ __version__ = "0.0.2rc6" - diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index f78cbbf4..779f5094 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -13,6 +13,7 @@ ) from napari_cellseg3d.config import MODEL_LIST + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index f0470990..5d5ada20 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -4,10 +4,7 @@ from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import ( - UTILITIES_WIDGETS, - Utilities, -) +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS rand_gen = Generator(PCG64(12345)) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index db0ffa79..c19aa3e8 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from functools import partial from typing import List + import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget @@ -10,14 +11,15 @@ from skimage.segmentation import watershed from tifffile import imread -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index 9591d054..ee566be7 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -5,6 +5,7 @@ create_decoders, create_encoders, ) +from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 5f66835b..2dc8f07c 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,4 +1,5 @@ from pathlib import Path + import napari import numpy as np from qtpy.QtWidgets import QSizePolicy @@ -350,6 +351,7 @@ def _start(self): self.images_filepaths, ) + class ToInstanceUtils(BasePluginFolder): """ Widget to convert semantic labels to instance labels diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index dd98bcd7..77149208 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -16,6 +16,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index cbc2103d..127ad0d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -11,14 +11,11 @@ # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import ( - AnisoUtils, - RemoveSmallUtils, - ThresholdUtils, - ToInstanceUtils, - ToSemanticUtils, -) -from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget +from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils +from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils +from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 71997786..f3efabc3 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -8,7 +8,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 09269497..d8476bc1 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -6,11 +6,8 @@ import napari # Qt -from qtpy import QtCore - # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt +from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt from qtpy.QtCore import QUrl diff --git a/pyproject.toml b/pyproject.toml index 38a03414..c9a9d942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,4 +127,3 @@ onnx-gpu = [ "onnx", "onnxruntime-gpu" ] - From 234dfaade75cbf65bb311877cbc47a5ce7077225 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:12:49 +0100 Subject: [PATCH 469/577] Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling --- napari_cellseg3d/_tests/test_interface.py | 7 +-- .../_tests/test_plugin_inference.py | 2 + napari_cellseg3d/_tests/test_training.py | 21 ++------ .../code_models/models/model_test.py | 34 ++++++------- napari_cellseg3d/code_models/workers.py | 37 +++++++++++--- .../code_plugins/plugin_convert.py | 1 + .../code_plugins/plugin_model_inference.py | 14 +++--- .../code_plugins/plugin_model_training.py | 49 +++++++++++++++++++ napari_cellseg3d/config.py | 16 +++--- napari_cellseg3d/interface.py | 6 ++- requirements.txt | 4 +- 11 files changed, 125 insertions(+), 66 deletions(-) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index 08e0e675..840f7a93 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,6 +1,7 @@ from napari_cellseg3d.interface import AnisotropyWidgets, Log + def test_log(qtbot): log = Log() log.print_and_log("test") @@ -12,9 +13,3 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) - - -def test_zoom_factor(): - resolution = [10.0, 10.0, 5.0] - zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) - assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 779f5094..5296ef96 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -14,6 +14,8 @@ from napari_cellseg3d.config import MODEL_LIST +def test_inference(make_napari_viewer, qtbot): + def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index e7f1e07b..5b642e77 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -39,23 +39,10 @@ def test_training(make_napari_viewer, qtbot): widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - worker_config = widget._set_worker_config() - worker = widget._create_worker_from_config(worker_config) - worker.config.train_data_dict = [{"image": im_path, "label": im_path}] - worker.config.val_data_dict = [{"image": im_path, "label": im_path}] - worker.config.max_epochs = 1 - worker.log_parameters() - res = next(worker.train()) - - assert isinstance(res, TrainingReport) - - # def on_error(e): - # print(e) - # assert False - # - # with qtbot.waitSignal( - # signal=widget.worker.finished, timeout=10000, raising=True - # ) as blocker: + # widget.start() + # assert widget.worker is not None + + # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. # blocker.connect(widget.worker.errored) # widget.worker.error_signal.connect(on_error) # widget.worker.train() diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 28f3a05b..d34f29e9 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -3,30 +3,30 @@ class TestModel(nn.Module): - use_default_training = True - weights_file = "test.pth" - - def __init__(self, **kwargs): + def __init__(self): super().__init__() self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - # def get_output(self, _, input): - # return input + def get_net(self): + return self - # def get_validation(self, val_inputs): - # return val_inputs + def get_output(self, _, input): + return input + def get_validation(self, val_inputs): + return val_inputs -if __name__ == "__main__": - model = TestModel() - model.train() - model.zero_grad() - from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR - torch.save( - model.state_dict(), - PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", - ) +# if __name__ == "__main__": +# +# model = TestModel() +# model.train() +# model.zero_grad() +# from napari_cellseg3d.config import WEIGHTS_DIR +# torch.save( +# model.state_dict(), +# WEIGHTS_DIR + f"/{get_weights_file()}" +# ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index dedade61..3c4cfcaa 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -54,6 +54,8 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d import utils + +# local from napari_cellseg3d.code_models.model_instance_seg import ImageStats from napari_cellseg3d.code_models.model_instance_seg import volume_stats @@ -503,10 +505,33 @@ def model_output( ): inputs = inputs.to("cpu") - # def model_output(inputs): - # return post_process_transforms( - # self.config.model_info.get_model().get_output(model, inputs) - # ) + model_output = lambda inputs: post_process_transforms( + self.config.model_info.get_model().get_output( + model, inputs + ) # TODO(cyril) refactor those functions + ) + + def model_output(inputs): + return post_process_transforms( + self.config.model_info.get_model().get_output(model, inputs) + ) + + if self.config.keep_on_cpu: + dataset_device = "cpu" + else: + dataset_device = self.config.device + + window_size = self.config.sliding_window_config.window_size + window_overlap = self.config.sliding_window_config.window_overlap + + # FIXME + # import sys + + # old_stdout = sys.stdout + # old_stderr = sys.stderr + + # sys.stdout = self.downloader.log_widget + # sys.stdout = self.downloader.log_widget dataset_device = ( "cpu" if self.config.keep_on_cpu else self.config.device @@ -681,8 +706,8 @@ def instance_seg( if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) + method = self.config.post_process_config.instance + instance_labels = method.run_method(to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 2dc8f07c..f6ce6552 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -11,6 +11,7 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder MAX_W = ui.UTILS_MAX_WIDTH diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 27419f80..73687bc0 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -19,6 +19,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -603,12 +608,9 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) + self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ + self.instance_widgets.method_choice.currentText() + ] self.post_process_config = config.PostProcessConfig( zoom=zoom_config, diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 3e666dcc..b4ca9848 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -810,6 +810,55 @@ def start(self): self.data = None raise err + model_config = config.ModelInfo( + name=self.model_choice.currentText() + ) + + self.weights_config.path = self.weights_config.path + self.weights_config.custom = self.custom_weights_choice.isChecked() + self.weights_config.use_pretrained = ( + not self.use_transfer_choice.isChecked() + ) + + deterministic_config = config.DeterministicConfig( + enabled=self.use_deterministic_choice.isChecked(), + seed=self.box_seed.value(), + ) + + validation_percent = ( + self.validation_percent_choice.slider_value / 100 + ) + + results_path_folder = Path( + self.results_path + + f"/{model_config.name}_{utils.get_date_time()}" + ) + Path(results_path_folder).mkdir( + parents=True, exist_ok=False + ) # avoid overwrite where possible + + patch_size = [w.value() for w in self.patch_size_widgets] + + logger.debug("Loading config...") + self.worker_config = config.TrainingWorkerConfig( + device=self.get_device(), + model_info=model_config, + weights_info=self.weights_config, + train_data_dict=self.data, + validation_percent=validation_percent, + max_epochs=self.epoch_choice.value(), + loss_function=self.get_loss(self.loss_choice.currentText()), + learning_rate=float(self.learning_rate_choice.currentText()), + validation_interval=self.val_interval_choice.value(), + batch_size=self.batch_choice.slider_value, + results_path_folder=str(results_path_folder), + sampling=self.patch_choice.isChecked(), + num_samples=self.sample_choice_slider.slider_value, + sample_size=patch_size, + do_augmentation=self.augment_choice.isChecked(), + deterministic_config=deterministic_config, + ) # TODO(cyril) continue to put params in config + self.config = config.TrainerConfig( save_as_zip=self.zip_choice.isChecked() ) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index f3efabc3..2049879e 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,8 +6,6 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod - # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -15,6 +13,12 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.model_instance_seg import ( + ConnectedComponents, + Watershed, + VoronoiOtsu, + InstanceMethod, +) from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -120,12 +124,6 @@ class Zoom: zoom_values: List[float] = None -@dataclass -class InstanceSegConfig: - enabled: bool = False - method: InstanceMethod = None - - @dataclass class PostProcessConfig: """Class to record params for post processing @@ -138,7 +136,7 @@ class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceSegConfig = InstanceSegConfig() + instance: InstanceMethod = None @dataclass diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d8476bc1..8839d041 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -10,6 +10,8 @@ from qtpy import QtCore from qtpy.QtCore import QObject from qtpy.QtCore import Qt + +# from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices @@ -1089,9 +1091,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.01) - else: self.setSingleStep(0.1) + else: + self.setSingleStep(1) @property def tooltips(self): diff --git a/requirements.txt b/requirements.txt index a7dd1570..9dbe5a00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pre-commit -pyclesperanto-prototype>=0.22.0 -pysqlite3 +pyclesperanto-prototype >=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From a46b590bf954935fa737ac24703e915bc28ecc6f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 13 Mar 2023 15:28:18 +0100 Subject: [PATCH 470/577] Disabled small removal in Voronoi-Otsu --- napari_cellseg3d/code_models/instance_segmentation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index c19aa3e8..d8d076b5 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -158,15 +158,13 @@ def voronoi_otsu( spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation + Returns: Instance segmentation labels from Voronoi-Otsu method """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - # semantic = np.squeeze(volume) - logger.debug( - f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" - ) + semantic = np.squeeze(volume) instance = cle.voronoi_otsu_labeling( volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -490,7 +488,6 @@ def __init__(self, widget_parent=None): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ From 3ce6f281e7d37f9c715b042980fed45c0492ea4a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 14 Mar 2023 08:20:04 +0100 Subject: [PATCH 471/577] Added new docs for instance seg --- .../code_models/instance_segmentation.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index d8d076b5..b6e59b23 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -39,14 +39,11 @@ def __init__( ): """ Methods for instance segmentation - Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function - widget_parent: parent for the declared widgets - """ self.name = name self.function = function @@ -402,10 +399,10 @@ def sphericity(region): ) -class Watershed(InstanceMethod): +class Watershed(InstanceMethod, metaclass=Singleton): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=WATERSHED, function=binary_watershed, @@ -448,10 +445,10 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod): +class ConnectedComponents(InstanceMethod, metaclass=Singleton): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=CONNECTED_COMP, function=binary_connected, @@ -479,10 +476,10 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod): +class VoronoiOtsu(InstanceMethod, metaclass=Singleton): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent=None): + def __init__(self): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, @@ -603,7 +600,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VORONOI_OTSU: VoronoiOtsu, - WATERSHED: Watershed, - CONNECTED_COMP: ConnectedComponents, + VoronoiOtsu().name: VoronoiOtsu, + Watershed().name: Watershed, + ConnectedComponents().name: ConnectedComponents, } From b73633751847c0067bffef8e2edd0a23911a0622 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 09:50:45 +0100 Subject: [PATCH 472/577] Docs + UI update - Updated welcome/README - Changed step for DoubleCounter --- README.md | 2 +- docs/res/welcome.rst | 2 -- napari_cellseg3d/interface.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ece6c6f4..ca8d0931 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 12a20630..045297f6 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -103,8 +103,6 @@ This plugin mainly uses the following libraries and software: * `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase -* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase - * A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 8839d041..1b44a889 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1091,9 +1091,9 @@ def __init__( def _update_step(self): if self.value() < 0.9: - self.setSingleStep(0.1) + self.setSingleStep(0.01) else: - self.setSingleStep(1) + self.setSingleStep(0.1) @property def tooltips(self): From ba53463bff85c231d6cbfc14399877428a766c4c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:07:33 +0100 Subject: [PATCH 473/577] Update requirements.txt Fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9dbe5a00..92aae176 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 -pyclesperanto-prototype >=0.22.0 +pyclesperanto-prototype>=0.22.0 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From b286727123b85948a9d784a092608deaecab2f18 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:20:58 +0100 Subject: [PATCH 474/577] isort --- napari_cellseg3d/_tests/fixtures.py | 2 ++ napari_cellseg3d/_tests/test_plugin_inference.py | 8 +------- napari_cellseg3d/_tests/test_training.py | 5 +---- napari_cellseg3d/_tests/test_weight_download.py | 7 +++---- napari_cellseg3d/code_models/instance_segmentation.py | 11 ++++++----- napari_cellseg3d/code_plugins/plugin_convert.py | 1 - .../code_plugins/plugin_model_inference.py | 8 +++----- napari_cellseg3d/config.py | 11 +++++------ napari_cellseg3d/interface.py | 4 ++-- 9 files changed, 23 insertions(+), 34 deletions(-) diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index b3044799..da34ae9b 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,3 +1,5 @@ +import warnings + from qtpy.QtWidgets import QTextEdit from napari_cellseg3d.utils import LOGGER as logger diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 5296ef96..a17120c3 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,14 +3,8 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.instance_segmentation import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import ( - InferenceResult, - Inferer, -) +from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer from napari_cellseg3d.config import MODEL_LIST diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 5b642e77..9adb3286 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,10 +3,7 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import ( - Trainer, - TrainingReport, -) +from napari_cellseg3d.code_plugins.plugin_model_training import Trainer from napari_cellseg3d.config import MODEL_LIST diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index be694d99..72dc939d 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,7 +1,6 @@ -from napari_cellseg3d.code_models.workers import ( - PRETRAINED_WEIGHTS_DIR, - WeightsDownloader, -) +from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import WeightsDownloader + diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index b6e59b23..ef181e3e 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -5,21 +5,22 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget +from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed +from skimage.transform import resize + +# from skimage.measure import mesh_surface_area +# from skimage.measure import marching_cubes from tifffile import imread from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes - - # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index f6ce6552..2dc8f07c 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -11,7 +11,6 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_instance_seg import threshold from napari_cellseg3d.code_models.model_instance_seg import to_semantic -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder MAX_W = ui.UTILS_MAX_WIDTH diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 73687bc0..478a8755 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -15,15 +15,13 @@ InstanceWidgets, ) from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_instance_seg import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 2049879e..f3e4e478 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,6 +6,11 @@ import napari import numpy as np +from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu +from napari_cellseg3d.code_models.model_instance_seg import Watershed + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR @@ -13,12 +18,6 @@ model_TRAILMAP_MS as TRAILMAP_MS, ) from napari_cellseg3d.code_models.models import model_VNet as VNet -from napari_cellseg3d.code_models.model_instance_seg import ( - ConnectedComponents, - Watershed, - VoronoiOtsu, - InstanceMethod, -) from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 1b44a889..c64cca19 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -8,10 +8,10 @@ # Qt # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt # from qtpy.QtCore import QtWarningMsg +from qtpy.QtCore import QObject +from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices From 85976b6c23eb67812e668f89b42f233759496823 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 10:40:06 +0100 Subject: [PATCH 475/577] Fix tests --- napari_cellseg3d/_tests/conftest.py | 1 - napari_cellseg3d/_tests/pytest.ini | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index bbfeff10..4d4a4007 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,5 +1,4 @@ import os - import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 45c3be1c..814cca2e 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 +qt_api=pyqt5 \ No newline at end of file From ef67ac19c74e85ba1d01d2ba74cba1ac942f423e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:10:56 +0100 Subject: [PATCH 476/577] Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init --- .../code_models/instance_segmentation.py | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index ef181e3e..08eee806 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -18,8 +18,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import Singleton from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import LOGGER as logger # from napari_cellseg3d.utils import sphericity_volume_area @@ -36,7 +36,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None, + widget_parent: QWidget = None ): """ Methods for instance segmentation @@ -45,6 +45,7 @@ def __init__( function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets """ self.name = name self.function = function @@ -56,14 +57,7 @@ def __init__( setattr( self, widget, - ui.Slider( - 0, - 100, - 1, - divide_factor=100, - text_label="", - parent=None, - ), + ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), ) self.sliders.append(getattr(self, widget)) @@ -400,16 +394,16 @@ def sphericity(region): ) -class Watershed(InstanceMethod, metaclass=Singleton): +class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent, + widget_parent=widget_parent ) self.sliders[0].label.setText("Foreground probability threshold") @@ -446,16 +440,16 @@ def run_method(self, image): ) -class ConnectedComponents(InstanceMethod, metaclass=Singleton): +class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self): + def __init__(self, widget_parent = None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent, + widget_parent=widget_parent ) self.sliders[0].label.setText("Foreground probability threshold") @@ -477,15 +471,16 @@ def run_method(self, image): ) -class VoronoiOtsu(InstanceMethod, metaclass=Singleton): +class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self): + def __init__(self, widget_parent): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, num_sliders=0, num_counters=2, + widget_parent=widget_parent ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ @@ -557,9 +552,8 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) - self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets ? + # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -570,14 +564,13 @@ def _build(self): group.layout.addWidget(counter) self.instance_widgets[name].append(counter) except RuntimeError as e: - logger.debug( - f"Caught runtime error {e}, most likely during testing" - ) + logger.debug(f"Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): + for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: @@ -601,7 +594,7 @@ def run_method(self, volume): INSTANCE_SEGMENTATION_METHOD_LIST = { - VoronoiOtsu().name: VoronoiOtsu, - Watershed().name: Watershed, - ConnectedComponents().name: ConnectedComponents, + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, } From 07da249b15d44ab87558aaca8ef22ecb0aedd682 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 15 Mar 2023 11:44:19 +0100 Subject: [PATCH 477/577] Fix inference --- .../code_models/instance_segmentation.py | 1 + napari_cellseg3d/code_models/workers.py | 8 +- .../code_plugins/plugin_model_inference.py | 17 +- napari_cellseg3d/config.py | 6 +- notebooks/assess_instance.ipynb | 479 +----------------- 5 files changed, 32 insertions(+), 479 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 08eee806..df86f0f9 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -552,6 +552,7 @@ def _build(self): try: for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class self.instance_widgets[name] = [] # moderately unsafe way to init those widgets if len(method_class.sliders) > 0: diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 3c4cfcaa..b94c8d23 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -643,9 +643,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct + instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -706,8 +704,8 @@ def instance_seg( if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - method = self.config.post_process_config.instance - instance_labels = method.run_method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 478a8755..201452ac 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -606,9 +606,10 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - self.instance_config = INSTANCE_SEGMENTATION_METHOD_LIST[ - self.instance_widgets.method_choice.currentText() - ] + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + ) self.post_process_config = config.PostProcessConfig( zoom=zoom_config, @@ -876,11 +877,13 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + method_name = self.worker_config.post_process_config.instance.method.name - viewer.add_labels(result.instance_labels, name=name) + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background + + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index f3e4e478..af5e8c3b 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -122,6 +122,10 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: InstanceMethod = None @dataclass class PostProcessConfig: @@ -135,7 +139,7 @@ class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() - instance: InstanceMethod = None + instance: InstanceSegConfig = InstanceSegConfig() @dataclass diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..40412282 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,500 +4,47 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "tags": [] + "collapsed": true }, "outputs": [], "source": [ - "import napari\n", "import numpy as np\n", - "from pathlib import Path\n", "from tifffile import imread\n", - "\n", - "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", - "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import (\n", - " binary_connected,\n", - " binary_watershed,\n", - " voronoi_otsu,\n", - " to_semantic,\n", - ")" + "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [] - }, + "execution_count": null, "outputs": [], - "source": [ - "viewer = napari.Viewer()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", - "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", - "\n", - "prediction = imread(prediction_path)\n", - "gt_labels = imread(gt_labels_path)\n", - "\n", - "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", - "gt_labels_resized = resize(gt_labels, zoom)\n", - "\n", - "\n", - "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.5817600487210719" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from napari_cellseg3d.utils import dice_coeff\n", - "\n", - "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "source": [], "metadata": { "collapsed": false, - "jupyter": { - "outputs_hidden": false + "pycharm": { + "name": "#%%\n" } - }, - "outputs": [], - "source": [ - "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", - "\n", - "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n", - "125\n" - ] - } - ], - "source": [ - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)\n", - "print(np.unique(gt_labels_resized).shape[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "connected = binary_connected(prediction_resized, thres_small=2)\n", - "viewer.add_labels(connected, name=\"connected\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, connected)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(68,\n", - " 43,\n", - " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "watershed = binary_watershed(\n", - " prediction_resized, thres_small=2, rem_seed_thres=1\n", - ")\n", - "viewer.add_labels(watershed)\n", - "eval.evaluate_model_performance(gt_labels_resized, watershed)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(25, 64, 64)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", - "\n", - "from skimage.morphology import remove_small_objects\n", - "\n", - "voronoi = remove_small_objects(voronoi, 2)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "dtype('int64')" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gt_labels_resized.dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# np.unique(voronoi, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# np.unique(gt_labels_resized, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [], - "source": [ - "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" - ] + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file From 77405174324c97f4107bc2970dfe62a44fe9aa90 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 15:29:38 +0100 Subject: [PATCH 478/577] Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../code_models/instance_segmentation.py | 5 +- .../dev_scripts/artefact_labeling.py | 130 +++--- .../dev_scripts/correct_labels.py | 133 ++---- .../dev_scripts/evaluate_labels.py | 408 ++++-------------- notebooks/assess_instance.ipynb | 401 ++++++++++++++++- 5 files changed, 581 insertions(+), 496 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index df86f0f9..c65a2282 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -40,12 +40,14 @@ def __init__( ): """ Methods for instance segmentation + Args: name: Name of the instance segmentation method (for UI) function: Function to use for instance segmentation num_sliders: Number of Slider UI elements needed to set the parameters of the function num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function widget_parent: parent for the declared widgets + """ self.name = name self.function = function @@ -150,7 +152,6 @@ def voronoi_otsu( spot_sigma (float): parameter determining how close detected objects can be outline_sigma (float): determines the smoothness of the segmentation - Returns: Instance segmentation labels from Voronoi-Otsu method @@ -176,6 +177,8 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 + scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) + """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..875ca9b6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,17 +1,15 @@ -import os - -import napari import numpy as np -import scipy.ndimage as ndimage -from skimage.filters import threshold_otsu from tifffile import imread from tifffile import imwrite - -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - +from pathlib import Path +import scipy.ndimage as ndimage +import os +import napari # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from skimage.filters import threshold_otsu """ New code by Yves Paychere @@ -46,9 +44,7 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append( - np.array([i, unique[np.argmax(counts)]]) - ) + map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -65,7 +61,7 @@ def map_labels(labels, artefacts): def make_labels( - image, + path_image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -77,7 +73,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - image : str + path_image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -97,25 +93,21 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - # image = imread(image) + image = imread(path_image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( - np.max(image_contrasted) - np.min(image_contrasted) - ) + image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size( - labels, min_size=threshold_size, is_labeled=True - ) + labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -127,26 +119,26 @@ def make_labels( ) -def select_image_by_labels(image, labels, path_image_out, label_values): +def select_image_by_labels(path_image, path_labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - image : np.array - image. - labels : np.array - labels. + path_image : str + Path to image. + path_labels : str + Path to labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - # image = imread(image) - # labels = imread(labels) + image = imread(path_image) + labels = imread(path_labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smallest cube that contains all the non-zero pixels of a 3d image +# select the smalles cube that contains all the none zero pixel of an 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -164,15 +156,16 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(image, path_image_out): +def crop_image_path(path_image, path_image_out): """Crop image. Parameters ---------- - image : np.array - image + path_image : str + Path to image. path_image_out : str Path of the output image. """ + image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -220,9 +213,7 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile( - image[neurons], threshold_artefact_brightness_percent - ) + threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -253,9 +244,7 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile( - sizes, threshold_artefact_size_percent - ) + neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -305,8 +294,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image, - labels, + image_path, + labels_path, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -315,10 +304,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image : np.array - image for artefact detection. - labels : np.array - label image array with each neurons labelled as a different int value. + image_path : str + Path to image file. + labels_path : str + Path to label image file with each neurons labelled as a different value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -328,6 +317,9 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ + image = imread(image_path) + labels = imread(labels_path) + artefacts = make_artefact_labels( image, labels, @@ -347,12 +339,11 @@ def visualize_images(paths): Parameters ---------- paths : list - List of images to visualize. + List of paths to images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - image = imread(path) - viewer.add_image(image) + viewer.add_image(imread(path), name=os.path.basename(path)) # wait for the user to close the viewer napari.run() @@ -379,12 +370,8 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [ - f for f in os.listdir(path + "/labels") if f.endswith(".tif") - ] - path_images = [ - f for f in os.listdir(path + "/volumes") if f.endswith(".tif") - ] + path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] + path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] # sort the list path_labels.sort() path_images.sort() @@ -412,22 +399,23 @@ def create_artefact_labels_from_folder( ) -# if __name__ == "__main__": -# repo_path = Path(__file__).resolve().parents[1] -# print(f"REPO PATH : {repo_path}") -# paths = [ -# "dataset_clean/cropped_visual/train", -# "dataset_clean/cropped_visual/val", -# "dataset_clean/somatomotor", -# "dataset_clean/visual_tif", -# ] -# for data_path in paths: -# path = str(repo_path / data_path) -# print(path) -# create_artefact_labels_from_folder( -# path, -# do_visualize=False, -# threshold_artefact_brightness_percent=20, -# threshold_artefact_size_percent=1, -# contrast_power=20, -# ) +if __name__ == "__main__": + + repo_path = Path(__file__).resolve().parents[1] + print(f"REPO PATH : {repo_path}") + paths = [ + "dataset_clean/cropped_visual/train", + "dataset_clean/cropped_visual/val", + "dataset_clean/somatomotor", + "dataset_clean/visual_tif", + ] + for data_path in paths: + path = str(repo_path / data_path) + print(path) + create_artefact_labels_from_folder( + path, + do_visualize=False, + threshold_artefact_brightness_percent=20, + threshold_artefact_size_percent=1, + contrast_power=20, + ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..f94327e2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,24 +1,19 @@ -import threading -import time -import warnings -from functools import partial -from pathlib import Path - -import napari import numpy as np -import scipy.ndimage as ndimage -from napari.qt.threading import thread_worker from tifffile import imread from tifffile import imwrite +import scipy.ndimage as ndimage +import napari +from pathlib import Path +import time +import warnings +from napari.qt.threading import thread_worker from tqdm import tqdm - -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed - +import threading # import sys # sys.path.append(str(Path(__file__) / "../../")) - +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -38,9 +33,7 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm( - range(len(unique_label)), desc="relabeling", ncols=100 - ): + for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): i = unique_label[i_label] if i == 0: continue @@ -88,16 +81,13 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact, test=False): +def ask_labels(unique_artefact): global returns returns = [] - if not test: - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] - else: - i_labels_to_add_tmp = [0] + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] if i_labels_to_add_tmp == [0]: print("no label added") @@ -140,15 +130,7 @@ def ask_labels(unique_artefact, test=False): print("close the napari window to continue") -def relabel( - image_path, - label_path, - go_fast=False, - check_for_unicity=True, - delay=0.3, - viewer=None, - test=False, -): +def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -162,8 +144,6 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 - viewer : napari.Viewer, optional - the napari viewer, by default None """ global returns @@ -178,10 +158,7 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - if not test: - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -201,52 +178,30 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread( - target=partial(ask_labels, test=test), args=(unique_artefact,) - ) + t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add), 0, artefact - ) - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - viewer.add_image(image, name="image") + artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - if not test: - napari.run() + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where( - np.isin(artefact, i_labels_to_add_tmp), artefact, 0 - ) + artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) print("these labels will be added") - if test: - viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer - if not test: - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") - if test: - revert = "n" - viewer.close() + viewer = napari.view_image(image) + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) - if test: - break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -303,16 +258,12 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget( - old_label, new_label, map_labels_existing, delay=0.5 -): +def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect( - lambda arg: modify_viewer(old_label, new_label, arg) - ) + worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -329,12 +280,8 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] - ) - new_label.colormap.colors = np.array( - [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] - ) + old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) + new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -343,9 +290,7 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget( - old_label, new_label, map_labels_existing, delay=delay - ) + create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) napari.run() @@ -362,14 +307,14 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, - str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") -# -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +if __name__ == "__main__": + + im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") + image_path = str(im_path / "image.tif") + gt_labels_path = str(im_path / "labels.tif") + + relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index ee9919b6..857bcd19 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,20 +1,74 @@ -import napari import numpy as np import pandas as pd from tqdm import tqdm +import napari from napari_cellseg3d.utils import LOGGER as log +def map_labels(labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > 0.5: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + if ratio_pixel_found > 0.8: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + # if total_pixel_found > np.sum(counts): + # raise ValueError( + # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" + # ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance( - labels, - model_labels, - threshold_correct=PERCENT_CORRECT, - print_details=False, - visualize=False, -): +def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): """Evaluate the model performance. Parameters ---------- @@ -22,10 +76,8 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - print_details : bool + do_print : bool If True, print the results. - visualize : bool - If True, visualize the results. Returns ------- neuron_found : float @@ -49,7 +101,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels, threshold_correct + labels, model_labels ) # calculate the number of neurons individually found @@ -67,9 +119,7 @@ def evaluate_model_performance( artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean( - [i[3] for i in map_labels_existing] - ) + mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -78,9 +128,7 @@ def evaluate_model_performance( if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean( - [i[2] for i in map_fused_neurons] - ) + mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -95,37 +143,27 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info( - f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" - ) - log.info( - f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" - ) - - if print_details: - log.info(f"Neurons found: {neurons_found}") - log.info(f"Neurons fused: {neurons_fused}") - log.info(f"Neurons not found: {neurons_not_found}") - log.info(f"Artefacts found: {artefacts_found}") - log.info( - f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + if do_print: + print("Neurons found: ", neurons_found) + print("Neurons fused: ", neurons_fused) + print("Neurons not found: ", neurons_not_found) + print("Artefacts found: ", artefacts_found) + print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) + print( + "Mean ratio of the neurons pixels correctly labelled: ", + mean_ratio_pixel_found, ) - log.info( - f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" + print( + "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + mean_ratio_pixel_found_fused, ) - log.info( - f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" + print( + "Mean true positive ratio of the model for fused neurons: ", + mean_true_positive_ratio_model_fused, ) - log.info( - f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" + print( + "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact ) - log.info( - f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" - ) - if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -141,21 +179,15 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) is False, - unique_labels, - 0, + np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where( - np.isin(labels, neurones_not_found_labels), labels, 0 - ) + ] + not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), - model_labels, - 0, + np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -183,81 +215,6 @@ def evaluate_model_performance( ) -def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > threshold_correct: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > threshold_correct: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels - - def save_as_csv(results, path): """ Save the results as a csv file @@ -269,7 +226,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - log.debug(np.array(results).shape) + print(np.array(results).shape) df = pd.DataFrame( [results], columns=[ @@ -287,193 +244,6 @@ def save_as_csv(results, path): df.to_csv(path, index=False) -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. -# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 40412282..b68ab83e 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,47 +4,426 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": true + "pycharm": { + "is_executing": true + }, + "tags": [] }, "outputs": [], "source": [ + "import napari\n", "import numpy as np\n", + "from pathlib import Path\n", "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n" + ] + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1/5,1,1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", + "viewer.add_labels(gt_labels_resized, name='gt')\n", + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 124\n", + "Neurons fused: 0\n", + "Neurons not found: 0\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled: 1.0\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", + "Mean true positive ratio of the model for fused neurons: nan\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized)\n", + "viewer.add_labels(connected,name='connected')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 45\n", + "Neurons fused: 38\n", + "Neurons not found: 41\n", + "Artefacts found: 8\n", + "Mean true positive ratio of the model: 0.8424215218790255\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", + "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", + "Mean ratio of false pixel in artefacts: 1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neurons found: 47\n", + "Neurons fused: 37\n", + "Neurons not found: 40\n", + "Artefacts found: 0\n", + "Mean true positive ratio of the model: 0.8426909426266451\n", + "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", + "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", + "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", + "Mean ratio of false pixel in artefacts: nan\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": true + } + }, "outputs": [], - "source": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" + ] + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { - "name": "#%%\n" + "is_executing": true } - } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} From fc957e240af0cd1e5c5bf806497f5771364c8c94 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 17 Mar 2023 16:23:26 +0100 Subject: [PATCH 479/577] Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> --- .../dev_scripts/evaluate_labels.py | 22 +- notebooks/assess_instance.ipynb | 408 ++++++++++++------ 2 files changed, 301 insertions(+), 129 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 857bcd19..b4436ccb 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -4,6 +4,7 @@ import napari from napari_cellseg3d.utils import LOGGER as log + def map_labels(labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -33,10 +34,12 @@ def map_labels(labels, model_labels): unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 + + print(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - log.debug(f"unique: {unique[ii]}") + print(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -50,8 +53,7 @@ def map_labels(labels, model_labels): tmp_map.append( [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] ) - if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + if len(tmp_map) == 1: # map to only one true neuron -> found neuron @@ -59,12 +61,14 @@ def map_labels(labels, model_labels): elif len(tmp_map) > 1: # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): - # if total_pixel_found > np.sum(counts): - # raise ValueError( - # f"total_pixel_found > np.sum(counts[ii]) : {total_pixel_found} > {np.sum(counts)}" - # ) + if total_pixel_found > np.sum(counts): + raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map + + # print(f"map_labels_existing: {map_labels_existing}") + print(f"map_fused_neurons: {map_fused_neurons}") + # print(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels @@ -99,7 +103,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - log.debug("Mapping labels...") + print("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -109,7 +113,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - log.debug("Calculating the number of neurons not found...") + print("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b68ab83e..6e6a9b5f 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -111,17 +111,274 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mapping labels...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2926.92it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "i: 1\n", + "unique: 1\n", + "i: 2\n", + "unique: 2\n", + "i: 3\n", + "unique: 3\n", + "i: 4\n", + "unique: 4\n", + "i: 5\n", + "unique: 5\n", + "i: 6\n", + "unique: 6\n", + "i: 7\n", + "unique: 7\n", + "i: 8\n", + "unique: 8\n", + "i: 9\n", + "unique: 9\n", + "i: 10\n", + "unique: 10\n", + "i: 11\n", + "unique: 11\n", + "i: 12\n", + "unique: 12\n", + "i: 13\n", + "unique: 13\n", + "i: 14\n", + "unique: 14\n", + "i: 15\n", + "unique: 15\n", + "i: 16\n", + "unique: 16\n", + "i: 17\n", + "unique: 17\n", + "i: 18\n", + "unique: 18\n", + "i: 19\n", + "unique: 19\n", + "i: 20\n", + "unique: 20\n", + "i: 21\n", + "unique: 21\n", + "i: 22\n", + "unique: 22\n", + "i: 23\n", + "unique: 23\n", + "i: 24\n", + "unique: 24\n", + "i: 25\n", + "unique: 25\n", + "i: 26\n", + "unique: 26\n", + "i: 27\n", + "unique: 27\n", + "i: 28\n", + "unique: 28\n", + "i: 29\n", + "unique: 29\n", + "i: 30\n", + "unique: 30\n", + "i: 31\n", + "unique: 31\n", + "i: 32\n", + "unique: 32\n", + "i: 33\n", + "unique: 33\n", + "i: 34\n", + "unique: 34\n", + "i: 35\n", + "unique: 35\n", + "i: 36\n", + "unique: 36\n", + "i: 37\n", + "unique: 37\n", + "i: 38\n", + "unique: 38\n", + "i: 39\n", + "unique: 39\n", + "i: 40\n", + "unique: 40\n", + "i: 41\n", + "unique: 41\n", + "i: 42\n", + "unique: 42\n", + "i: 43\n", + "unique: 43\n", + "i: 44\n", + "unique: 44\n", + "i: 45\n", + "unique: 45\n", + "i: 46\n", + "unique: 46\n", + "i: 47\n", + "unique: 47\n", + "i: 48\n", + "unique: 48\n", + "i: 49\n", + "unique: 49\n", + "i: 50\n", + "unique: 50\n", + "i: 51\n", + "unique: 51\n", + "i: 52\n", + "unique: 52\n", + "i: 53\n", + "unique: 53\n", + "i: 54\n", + "unique: 54\n", + "i: 55\n", + "unique: 55\n", + "i: 56\n", + "unique: 56\n", + "i: 57\n", + "unique: 57\n", + "i: 58\n", + "unique: 58\n", + "i: 59\n", + "unique: 59\n", + "i: 60\n", + "unique: 60\n", + "i: 61\n", + "unique: 61\n", + "i: 62\n", + "unique: 62\n", + "i: 63\n", + "unique: 63\n", + "i: 64\n", + "unique: 64\n", + "i: 65\n", + "unique: 65\n", + "i: 66\n", + "unique: 66\n", + "i: 67\n", + "unique: 67\n", + "i: 68\n", + "unique: 68\n", + "i: 69\n", + "unique: 69\n", + "i: 70\n", + "unique: 70\n", + "i: 71\n", + "unique: 71\n", + "i: 72\n", + "unique: 72\n", + "i: 73\n", + "unique: 73\n", + "i: 74\n", + "unique: 74\n", + "i: 75\n", + "unique: 75\n", + "i: 76\n", + "unique: 76\n", + "i: 77\n", + "unique: 77\n", + "i: 78\n", + "unique: 78\n", + "i: 79\n", + "unique: 79\n", + "i: 80\n", + "unique: 80\n", + "i: 81\n", + "unique: 81\n", + "i: 82\n", + "unique: 82\n", + "i: 83\n", + "unique: 83\n", + "i: 84\n", + "unique: 84\n", + "i: 85\n", + "unique: 85\n", + "i: 86\n", + "unique: 86\n", + "i: 87\n", + "unique: 87\n", + "i: 88\n", + "unique: 88\n", + "i: 89\n", + "unique: 89\n", + "i: 90\n", + "unique: 90\n", + "i: 91\n", + "unique: 91\n", + "i: 93\n", + "unique: 93\n", + "i: 94\n", + "unique: 94\n", + "i: 95\n", + "unique: 95\n", + "i: 96\n", + "unique: 96\n", + "i: 97\n", + "unique: 97\n", + "i: 98\n", + "unique: 98\n", + "i: 99\n", + "unique: 99\n", + "i: 100\n", + "unique: 100\n", + "i: 101\n", + "unique: 101\n", + "i: 102\n", + "unique: 102\n", + "i: 103\n", + "unique: 103\n", + "i: 104\n", + "unique: 104\n", + "i: 105\n", + "unique: 105\n", + "i: 106\n", + "unique: 106\n", + "i: 107\n", + "unique: 107\n", + "i: 108\n", + "unique: 108\n", + "i: 109\n", + "unique: 109\n", + "i: 110\n", + "unique: 110\n", + "i: 111\n", + "unique: 111\n", + "i: 112\n", + "unique: 112\n", + "i: 113\n", + "unique: 113\n", + "i: 114\n", + "unique: 114\n", + "i: 115\n", + "unique: 115\n", + "i: 116\n", + "unique: 116\n", + "i: 117\n", + "unique: 117\n", + "i: 118\n", + "unique: 118\n", + "i: 119\n", + "unique: 119\n", + "i: 120\n", + "unique: 120\n", + "i: 121\n", + "unique: 121\n", + "i: 122\n", + "unique: 122\n", + "i: 123\n", + "unique: 123\n", + "i: 124\n", + "unique: 124\n", + "i: 125\n", + "unique: 125\n", + "map_fused_neurons: []\n", + "Calculating the number of neurons not found...\n", "Neurons found: 124\n", "Neurons fused: 0\n", "Neurons not found: 0\n", @@ -157,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -168,145 +425,66 @@ { "data": { "text/plain": [ - "" + "dtype('int32')" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')" + "viewer.add_labels(connected,name='connected')\n", + "connected.dtype" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 1912.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 45\n", - "Neurons fused: 38\n", - "Neurons not found: 41\n", - "Artefacts found: 8\n", - "Mean true positive ratio of the model: 0.8424215218790255\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9295584641244191\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9332428837640889\n", - "Mean true positive ratio of the model for fused neurons: 0.8300682812799907\n", - "Mean ratio of false pixel in artefacts: 1.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2290.34it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Neurons found: 47\n", - "Neurons fused: 37\n", - "Neurons not found: 40\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 0.8426909426266451\n", - "Mean ratio of the neurons pixels correctly labelled: 0.9355733839977192\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: 0.9369165405470844\n", - "Mean true positive ratio of the model for fused neurons: 0.8198206611354032\n", - "Mean ratio of false pixel in artefacts: nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -320,7 +498,7 @@ { "data": { "text/plain": [ - "(25, 64, 64)" + "dtype('int64')" ] }, "execution_count": 9, @@ -329,14 +507,12 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", - "viewer.add_labels(voronoi)\n", - "voronoi.shape" + "gt_labels_resized.dtype" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -353,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false, "jupyter": { @@ -374,15 +550,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 28%|██████████████████████▍ | 34/123 [07:33<22:45, 15.34s/it]" - ] - } - ], + "outputs": [], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] From 2bafffc3462516f5f3eb1f2da5bac67292129f71 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:05:19 +0100 Subject: [PATCH 480/577] Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black --- .../dev_scripts/artefact_labeling.py | 33 +- .../dev_scripts/correct_labels.py | 45 +- .../dev_scripts/evaluate_labels.py | 282 +++++++-- notebooks/assess_instance.ipynb | 553 ++++++++---------- 4 files changed, 563 insertions(+), 350 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 875ca9b6..b66ace64 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -5,6 +5,7 @@ import scipy.ndimage as ndimage import os import napari + # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -44,7 +45,9 @@ def map_labels(labels, artefacts): unique = np.flip(unique[np.argsort(counts)]) counts = np.flip(counts[np.argsort(counts)]) if unique[0] != 0: - map_labels_existing.append(np.array([i, unique[np.argmax(counts)]])) + map_labels_existing.append( + np.array([i, unique[np.argmax(counts)]]) + ) elif ( counts[0] < np.sum(counts) * 2 / 3.0 ): # the artefact is connected to multiple neurons @@ -100,14 +103,18 @@ def make_labels( image_contrasted = np.where(image > threshold_brightness, image, 0) if use_watershed: - image_contrasted= (image_contrasted - np.min(image_contrasted)) / (np.max(image_contrasted) - np.min(image_contrasted)) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) else: labels = ndimage.label(image_contrasted)[0] - labels = select_artefacts_by_size(labels, min_size=threshold_size, is_labeled=True) + labels = select_artefacts_by_size( + labels, min_size=threshold_size, is_labeled=True + ) if not do_multi_label: labels = np.where(labels > 0, label_value, 0) @@ -119,7 +126,9 @@ def make_labels( ) -def select_image_by_labels(path_image, path_labels, path_image_out, label_values): +def select_image_by_labels( + path_image, path_labels, path_image_out, label_values +): """Select image by labels. Parameters ---------- @@ -213,7 +222,9 @@ def make_artefact_labels( # calculate the percentile of the intensity of all the pixels that are labeled as neurons # check if the neurons are not empty if np.sum(neurons) > 0: - threshold = np.percentile(image[neurons], threshold_artefact_brightness_percent) + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) else: # take the percentile of the non neurons if the neurons are empty threshold = np.percentile(image[non_neurons], 90) @@ -244,7 +255,9 @@ def make_artefact_labels( # calculate the percentile of the size of the neurons if np.sum(neurons) > 0: sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) - neurone_size_percentile = np.percentile(sizes, threshold_artefact_size_percent) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) else: # find the size of each connected component sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) @@ -370,8 +383,12 @@ def create_artefact_labels_from_folder( Power for contrast enhancement. """ # find all the images in the folder and create a list - path_labels = [f for f in os.listdir(path + "/labels") if f.endswith(".tif")] - path_images = [f for f in os.listdir(path + "/volumes") if f.endswith(".tif")] + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] # sort the list path_labels.sort() path_images.sort() diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index f94327e2..da938c01 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -9,11 +9,13 @@ from napari.qt.threading import thread_worker from tqdm import tqdm import threading + # import sys # sys.path.append(str(Path(__file__) / "../../")) from napari_cellseg3d.code_models.model_instance_seg import binary_watershed import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -33,7 +35,9 @@ def relabel_non_unique_i(label, save_path, go_fast=False): new_labels = np.zeros_like(label) map_labels_existing = [] unique_label = np.unique(label) - for i_label in tqdm(range(len(unique_label)), desc="relabeling", ncols=100): + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): i = unique_label[i_label] if i == 0: continue @@ -130,7 +134,9 @@ def ask_labels(unique_artefact): print("close the napari window to continue") -def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3): +def relabel( + image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 +): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters ---------- @@ -158,7 +164,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map(map_labels_existing, label_path, new_label_path, delay=delay) + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -180,7 +188,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay # visualize the artefact and ask the user which label to add to the label image t = threading.Thread(target=ask_labels, args=(unique_artefact,)) t.start() - artefact_copy = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") @@ -191,7 +201,9 @@ def relabel(image_path, label_path, go_fast=False, check_for_unicity=True, delay for i in i_labels_to_add: if i not in i_labels_to_add_tmp: i_labels_to_add_tmp.append(i) - artefact_copy = np.where(np.isin(artefact, i_labels_to_add_tmp), artefact, 0) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) print("these labels will be added") viewer = napari.view_image(image) viewer.add_labels(artefact_copy, name="labels added") @@ -258,12 +270,16 @@ def to_show(map_labels_existing, delay=0.5): time.sleep(delay) -def create_connected_widget(old_label, new_label, map_labels_existing, delay=0.5): +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): """Builds a widget that can control a function in another thread.""" worker = to_show(map_labels_existing, delay) worker.start() - worker.yielded.connect(lambda arg: modify_viewer(old_label, new_label, arg)) + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): @@ -280,8 +296,12 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label = viewer.add_labels(label, num_colors=3) new_label = viewer.add_labels(relabel, num_colors=3) - old_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0],[1.0, 1.0, 1.0, 1.0]]) - new_label.colormap.colors = np.array([[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0],[1.0, 0.0, 0.0, 1.0]]) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) # viewer.dims.ndisplay = 3 viewer.camera.angles = (180, 3, 50) @@ -290,7 +310,9 @@ def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): old_label.show_selected_label = True new_label.show_selected_label = True - create_connected_widget(old_label, new_label, map_labels_existing, delay=delay) + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) napari.run() @@ -307,7 +329,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if file.suffix == ".tif": label = imread(str(Path(folder_path / file))) relabel_non_unique_i( - label, str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")) + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), ) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index b4436ccb..cf8cfdda 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,15 +1,55 @@ import numpy as np +from collections import Counter +from dataclasses import dataclass import pandas as pd from tqdm import tqdm +from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -def map_labels(labels, model_labels): +PERCENT_CORRECT = 0.7 + +@dataclass +class LabelInfo: + gt_index: int + model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) + best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + overall_gt_label_coverage: float = 0.0 # true positive ration of the model + + def get_correct_ratio(self): + for model_label, status in self.model_labels_id_and_status.items(): + if status == "correct": + return self.best_model_label_coverage + else: + return None + +def eval_model(gt_labels, model_labels, print_report=False): + + report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + + per_label_perfs = [] + for report in report_list: + if print_report: + log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") + log.info(f"Best model label coverage : {report.best_model_label_coverage}") + log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + + perf = report.get_correct_ratio() + if perf is not None: + per_label_perfs.append(perf) + + per_label_perfs = np.array(per_label_perfs) + return per_label_perfs.mean(), new_labels, fused_labels + + + + +def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters ---------- - labels : ndarray + gt_labels : ndarray Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. @@ -22,6 +62,147 @@ def map_labels(labels, model_labels): new_labels: list The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ + + + map_labels_existing = [] + map_fused_neurons = {} + "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" + background_labels = model_labels[np.where((gt_labels == 0))] + "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" + new_labels = [] + for lab in np.unique(background_labels): + if lab == 0: + continue + gt_background_size_at_lab = ( + gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] + .flatten() + .shape[0] + ) + gt_lab_size = ( + gt_labels[np.where(model_labels == lab)].flatten().shape[0] + ) + if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: + new_labels.append(lab) + + label_report_list = [] + # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label + # model_label_values = {} # contains the model labels value assigned to each unique gt label + not_found_id = 0 + + for i in tqdm(np.unique(gt_labels)): + if i == 0: + continue + + gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label + + model_lab_on_gt = model_labels[ + np.where(((gt_labels == i) & (model_labels != 0))) + ] # all models labels on single gt_label + info = LabelInfo(i) + + info.model_labels_id_and_status = { + label_id: "" for label_id in np.unique(model_lab_on_gt) + } + + if model_lab_on_gt.shape[0] == 0: + info.model_labels_id_and_status[ + f"not_found_{not_found_id}" + ] = "not found" + not_found_id += 1 + label_report_list.append(info) + continue + + log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") + + # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label + log.debug( + f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" + ) + + ratio = [] + for model_lab_id in info.model_labels_id_and_status.keys(): + size_model_label = ( + model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] + .flatten() + .shape[0] + ) + size_gt_label = gt_label.flatten().shape[0] + + log.debug(f"size_model_label : {size_model_label}") + log.debug(f"size_gt_label : {size_gt_label}") + + ratio.append(size_model_label / size_gt_label) + + # log.debug(ratio) + ratio_model_lab_for_given_gt_lab = np.array(ratio) + info.best_model_label_coverage = ( + ratio_model_lab_for_given_gt_lab.max() + ) + + best_model_lab_id = model_lab_on_gt[ + np.argmax(ratio_model_lab_for_given_gt_lab) + ] + log.debug(f"best_model_lab_id : {best_model_lab_id}") + + info.overall_gt_label_coverage = ( + ratio_model_lab_for_given_gt_lab.sum() + ) # the ratio of the pixels of the true label correctly labelled + + if info.best_model_label_coverage > PERCENT_CORRECT: + info.model_labels_id_and_status[best_model_lab_id] = "correct" + # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] + else: + info.model_labels_id_and_status[best_model_lab_id] = "wrong" + for model_lab_id in np.unique(model_lab_on_gt): + if model_lab_id != best_model_lab_id: + log.debug(model_lab_id, "is wrong") + info.model_labels_id_and_status[model_lab_id] = "wrong" + + label_report_list.append(info) + + correct_labels_id = [] + for report in label_report_list: + for i_lab in report.model_labels_id_and_status.keys(): + if report.model_labels_id_and_status[i_lab] == "correct": + correct_labels_id.append(i_lab) + """Find all labels in label_report_list that are correct more than once""" + duplicated_labels = [ + item for item, count in Counter(correct_labels_id).items() if count > 1 + ] + "Sum up the size of all duplicated labels" + for i in duplicated_labels: + for report in label_report_list: + if ( + i in report.model_labels_id_and_status.keys() + and report.model_labels_id_and_status[i] == "correct" + ): + size = ( + model_labels[np.where(model_labels == i)] + .flatten() + .shape[0] + ) + map_fused_neurons[i] = size + + return label_report_list, new_labels, map_fused_neurons + + +def map_labels(gt_labels, model_labels): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ map_labels_existing = [] map_fused_neurons = [] new_labels = [] @@ -29,17 +210,17 @@ def map_labels(labels, model_labels): for i in tqdm(np.unique(model_labels)): if i == 0: continue - indexes = labels[model_labels == i] + indexes = gt_labels[model_labels == i] # find the most common labels in the label i of the model unique, counts = np.unique(indexes, return_counts=True) tmp_map = [] total_pixel_found = 0 - print(f"i: {i}") + # log.debug(f"i: {i}") for ii in range(len(unique)): true_positive_ratio_model = counts[ii] / np.sum(counts) # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - print(f"unique: {unique[ii]}") + # log.debug(f"unique: {unique[ii]}") if unique[ii] == 0: if true_positive_ratio_model > 0.5: # -> artifact found @@ -47,14 +228,20 @@ def map_labels(labels, model_labels): else: # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum(labels == unique[ii]) + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) if ratio_pixel_found > 0.8: total_pixel_found += np.sum(counts[ii]) tmp_map.append( - [i, unique[ii], ratio_pixel_found, true_positive_ratio_model] + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] ) - if len(tmp_map) == 1: # map to only one true neuron -> found neuron map_labels_existing.append(tmp_map[0]) @@ -62,17 +249,21 @@ def map_labels(labels, model_labels): # map to multiple true neurons -> fused neuron for ii in range(len(tmp_map)): if total_pixel_found > np.sum(counts): - raise ValueError(f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}") + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) tmp_map[ii][3] = total_pixel_found / np.sum(counts) map_fused_neurons += tmp_map - # print(f"map_labels_existing: {map_labels_existing}") - print(f"map_fused_neurons: {map_fused_neurons}") - # print(f"new_labels: {new_labels}") + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") return map_labels_existing, map_fused_neurons, new_labels -def evaluate_model_performance(labels, model_labels, do_print=True, visualize=False): +def evaluate_model_performance( + labels, model_labels, do_print=False, visualize=False +): """Evaluate the model performance. Parameters ---------- @@ -82,6 +273,8 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa Label image from the model labelled as mulitple values. do_print : bool If True, print the results. + visualize : bool + If True, visualize the results. Returns ------- neuron_found : float @@ -103,7 +296,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact: float The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) """ - print("Mapping labels...") + log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( labels, model_labels ) @@ -113,7 +306,7 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa # calculate the number of neurons fused neurons_fused = len(map_fused_neurons) # calculate the number of neurons not found - print("Calculating the number of neurons not found...") + log.debug("Calculating the number of neurons not found...") neurons_found_labels = np.unique( [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] ) @@ -123,7 +316,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa artefacts_found = len(new_labels) if len(map_labels_existing) > 0: # calculate the mean true positive ratio of the model - mean_true_positive_ratio_model = np.mean([i[3] for i in map_labels_existing]) + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) # calculate the mean ratio of the neurons pixels correctly labelled mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) else: @@ -132,7 +327,9 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa if len(map_fused_neurons) > 0: # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons - mean_ratio_pixel_found_fused = np.mean([i[2] for i in map_fused_neurons]) + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) # calculate the mean true positive ratio of the model for the fused neurons mean_true_positive_ratio_model_fused = np.mean( [i[3] for i in map_fused_neurons] @@ -148,26 +345,35 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa mean_ratio_false_pixel_artefact = np.nan if do_print: - print("Neurons found: ", neurons_found) - print("Neurons fused: ", neurons_fused) - print("Neurons not found: ", neurons_not_found) - print("Artefacts found: ", artefacts_found) - print("Mean true positive ratio of the model: ", mean_true_positive_ratio_model) - print( + log.info("Neurons found: ") + log.info(neurons_found) + log.info("Neurons fused: ") + log.info(neurons_fused) + log.info("Neurons not found: ") + log.info(neurons_not_found) + log.info("Artefacts found: ") + log.info(artefacts_found) + log.info( + "Mean true positive ratio of the model: ", + ) + log.info(mean_true_positive_ratio_model) + log.info( "Mean ratio of the neurons pixels correctly labelled: ", - mean_ratio_pixel_found, ) - print( + log.info(mean_ratio_pixel_found) + log.info( "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", - mean_ratio_pixel_found_fused, ) - print( + log.info(mean_ratio_pixel_found_fused) + log.info( "Mean true positive ratio of the model for fused neurons: ", - mean_true_positive_ratio_model_fused, ) - print( - "Mean ratio of false pixel in artefacts: ", mean_ratio_false_pixel_artefact + log.info(mean_true_positive_ratio_model_fused) + log.info( + "Mean ratio of false pixel in artefacts: " ) + log.info(mean_ratio_false_pixel_artefact) + if visualize: viewer = napari.Viewer() viewer.add_labels(labels, name="ground truth") @@ -183,15 +389,21 @@ def evaluate_model_performance(labels, model_labels, do_print=True, visualize=Fa ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, unique_labels, 0 + np.isin(unique_labels, neurons_found_labels) == False, + unique_labels, + 0, ) neurones_not_found_labels = neurones_not_found_labels[ neurones_not_found_labels != 0 - ] - not_found = np.where(np.isin(labels, neurones_not_found_labels), labels, 0) + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) viewer.add_labels(not_found, name="ground truth not found") artefacts_found = np.where( - np.isin(model_labels, [i[0] for i in new_labels]), model_labels, 0 + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, ) viewer.add_labels(artefacts_found, name="model's labels artefacts") fused_model = np.where( @@ -230,7 +442,7 @@ def save_as_csv(results, path): path: str The path to save the csv file """ - print(np.array(results).shape) + log.debug(np.array(results).shape) df = pd.DataFrame( [results], columns=[ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 6e6a9b5f..d521c395 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -18,7 +18,11 @@ "\n", "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", "from napari_cellseg3d.utils import resize\n", - "from napari_cellseg3d.code_models.model_instance_seg import binary_connected, binary_watershed, voronoi_otsu" + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + ")" ] }, { @@ -45,16 +49,6 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -72,13 +66,13 @@ "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", - "zoom = (1/5,1,1)\n", + "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", - "viewer.add_image(prediction_resized, name='pred', colormap='inferno')\n", - "viewer.add_labels(gt_labels_resized, name='gt')\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", "print(prediction_resized.shape)\n", "print(gt_labels_resized.shape)" ] @@ -98,6 +92,7 @@ "outputs": [], "source": [ "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" ] }, @@ -115,279 +110,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mapping labels...\n" + "2023-03-22 14:47:30,112 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 2953.37it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "i: 1\n", - "unique: 1\n", - "i: 2\n", - "unique: 2\n", - "i: 3\n", - "unique: 3\n", - "i: 4\n", - "unique: 4\n", - "i: 5\n", - "unique: 5\n", - "i: 6\n", - "unique: 6\n", - "i: 7\n", - "unique: 7\n", - "i: 8\n", - "unique: 8\n", - "i: 9\n", - "unique: 9\n", - "i: 10\n", - "unique: 10\n", - "i: 11\n", - "unique: 11\n", - "i: 12\n", - "unique: 12\n", - "i: 13\n", - "unique: 13\n", - "i: 14\n", - "unique: 14\n", - "i: 15\n", - "unique: 15\n", - "i: 16\n", - "unique: 16\n", - "i: 17\n", - "unique: 17\n", - "i: 18\n", - "unique: 18\n", - "i: 19\n", - "unique: 19\n", - "i: 20\n", - "unique: 20\n", - "i: 21\n", - "unique: 21\n", - "i: 22\n", - "unique: 22\n", - "i: 23\n", - "unique: 23\n", - "i: 24\n", - "unique: 24\n", - "i: 25\n", - "unique: 25\n", - "i: 26\n", - "unique: 26\n", - "i: 27\n", - "unique: 27\n", - "i: 28\n", - "unique: 28\n", - "i: 29\n", - "unique: 29\n", - "i: 30\n", - "unique: 30\n", - "i: 31\n", - "unique: 31\n", - "i: 32\n", - "unique: 32\n", - "i: 33\n", - "unique: 33\n", - "i: 34\n", - "unique: 34\n", - "i: 35\n", - "unique: 35\n", - "i: 36\n", - "unique: 36\n", - "i: 37\n", - "unique: 37\n", - "i: 38\n", - "unique: 38\n", - "i: 39\n", - "unique: 39\n", - "i: 40\n", - "unique: 40\n", - "i: 41\n", - "unique: 41\n", - "i: 42\n", - "unique: 42\n", - "i: 43\n", - "unique: 43\n", - "i: 44\n", - "unique: 44\n", - "i: 45\n", - "unique: 45\n", - "i: 46\n", - "unique: 46\n", - "i: 47\n", - "unique: 47\n", - "i: 48\n", - "unique: 48\n", - "i: 49\n", - "unique: 49\n", - "i: 50\n", - "unique: 50\n", - "i: 51\n", - "unique: 51\n", - "i: 52\n", - "unique: 52\n", - "i: 53\n", - "unique: 53\n", - "i: 54\n", - "unique: 54\n", - "i: 55\n", - "unique: 55\n", - "i: 56\n", - "unique: 56\n", - "i: 57\n", - "unique: 57\n", - "i: 58\n", - "unique: 58\n", - "i: 59\n", - "unique: 59\n", - "i: 60\n", - "unique: 60\n", - "i: 61\n", - "unique: 61\n", - "i: 62\n", - "unique: 62\n", - "i: 63\n", - "unique: 63\n", - "i: 64\n", - "unique: 64\n", - "i: 65\n", - "unique: 65\n", - "i: 66\n", - "unique: 66\n", - "i: 67\n", - "unique: 67\n", - "i: 68\n", - "unique: 68\n", - "i: 69\n", - "unique: 69\n", - "i: 70\n", - "unique: 70\n", - "i: 71\n", - "unique: 71\n", - "i: 72\n", - "unique: 72\n", - "i: 73\n", - "unique: 73\n", - "i: 74\n", - "unique: 74\n", - "i: 75\n", - "unique: 75\n", - "i: 76\n", - "unique: 76\n", - "i: 77\n", - "unique: 77\n", - "i: 78\n", - "unique: 78\n", - "i: 79\n", - "unique: 79\n", - "i: 80\n", - "unique: 80\n", - "i: 81\n", - "unique: 81\n", - "i: 82\n", - "unique: 82\n", - "i: 83\n", - "unique: 83\n", - "i: 84\n", - "unique: 84\n", - "i: 85\n", - "unique: 85\n", - "i: 86\n", - "unique: 86\n", - "i: 87\n", - "unique: 87\n", - "i: 88\n", - "unique: 88\n", - "i: 89\n", - "unique: 89\n", - "i: 90\n", - "unique: 90\n", - "i: 91\n", - "unique: 91\n", - "i: 93\n", - "unique: 93\n", - "i: 94\n", - "unique: 94\n", - "i: 95\n", - "unique: 95\n", - "i: 96\n", - "unique: 96\n", - "i: 97\n", - "unique: 97\n", - "i: 98\n", - "unique: 98\n", - "i: 99\n", - "unique: 99\n", - "i: 100\n", - "unique: 100\n", - "i: 101\n", - "unique: 101\n", - "i: 102\n", - "unique: 102\n", - "i: 103\n", - "unique: 103\n", - "i: 104\n", - "unique: 104\n", - "i: 105\n", - "unique: 105\n", - "i: 106\n", - "unique: 106\n", - "i: 107\n", - "unique: 107\n", - "i: 108\n", - "unique: 108\n", - "i: 109\n", - "unique: 109\n", - "i: 110\n", - "unique: 110\n", - "i: 111\n", - "unique: 111\n", - "i: 112\n", - "unique: 112\n", - "i: 113\n", - "unique: 113\n", - "i: 114\n", - "unique: 114\n", - "i: 115\n", - "unique: 115\n", - "i: 116\n", - "unique: 116\n", - "i: 117\n", - "unique: 117\n", - "i: 118\n", - "unique: 118\n", - "i: 119\n", - "unique: 119\n", - "i: 120\n", - "unique: 120\n", - "i: 121\n", - "unique: 121\n", - "i: 122\n", - "unique: 122\n", - "i: 123\n", - "unique: 123\n", - "i: 124\n", - "unique: 124\n", - "i: 125\n", - "unique: 125\n", - "map_fused_neurons: []\n", - "Calculating the number of neurons not found...\n", - "Neurons found: 124\n", - "Neurons fused: 0\n", - "Neurons not found: 0\n", - "Artefacts found: 0\n", - "Mean true positive ratio of the model: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled: 1.0\n", - "Mean ratio of the neurons pixels correctly labelled for fused neurons: nan\n", - "Mean true positive ratio of the model for fused neurons: nan\n", - "Mean ratio of false pixel in artefacts: nan\n" + "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" ] }, { @@ -414,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -428,66 +165,177 @@ "dtype('int32')" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected,name='connected')\n", + "viewer.add_labels(connected, name=\"connected\")\n", "connected.dtype" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,231 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(45,\n", + " 38,\n", + " 41,\n", + " 8,\n", + " 0.8424215218790255,\n", + " 0.9295584641244191,\n", + " 0.9332428837640889,\n", + " 0.8300682812799907,\n", + " 1.0)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, connected)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,344 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(47,\n", + " 37,\n", + " 40,\n", + " 0,\n", + " 0.8426909426266451,\n", + " 0.9355733839977192,\n", + " 0.9369165405470844,\n", + " 0.8198206611354032,\n", + " nan)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "watershed = binary_watershed(prediction_resized, thres_small=20, rem_seed_thres=5)\n", + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=20, rem_seed_thres=5\n", + ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 10)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -501,7 +349,7 @@ "dtype('int64')" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -512,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -522,42 +370,155 @@ "is_executing": true } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", + " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", + " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", + " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", + " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", + " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", + " 122], dtype=uint32),\n", + " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", + " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", + " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", + " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", + " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", + " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", + " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", + " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", + " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", + " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", + " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", + " 28, 36, 28, 14, 31, 54], dtype=int64))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(voronoi, return_counts=True)" + "np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", + " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", + " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", + " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", + " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", + " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", + " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", + " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", + " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", + " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", + " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", + " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", + " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", + " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", + " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", + " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", + " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", + " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", + " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", + " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", + " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", + " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", + " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", + " 33, 25, 7, 5, 7, 19, 32, 40],\n", + " dtype=int64))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# np.unique(gt_labels, return_counts=True)" + "np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,755 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(72,\n", + " 8,\n", + " 44,\n", + " 1,\n", + " 0.8348479609766444,\n", + " 0.9314226186350036,\n", + " 0.9483750072126669,\n", + " 0.8528417100412058,\n", + " 1.0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { From fd11680ad42e82690fe70dd5a30e77bc2326d5f3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:08:05 +0100 Subject: [PATCH 481/577] black --- .../code_models/instance_segmentation.py | 21 ++++++++---- napari_cellseg3d/code_models/workers.py | 4 ++- .../code_plugins/plugin_model_inference.py | 8 +++-- napari_cellseg3d/config.py | 2 ++ .../dev_scripts/evaluate_labels.py | 33 +++++++++++-------- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index c65a2282..f8fc2517 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -36,7 +36,7 @@ def __init__( function: callable, num_sliders: int, num_counters: int, - widget_parent: QWidget = None + widget_parent: QWidget = None, ): """ Methods for instance segmentation @@ -59,7 +59,14 @@ def __init__( setattr( self, widget, - ui.Slider(0, 100, 1, divide_factor=100, text_label="", parent=None), + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), ) self.sliders.append(getattr(self, widget)) @@ -400,13 +407,13 @@ def sphericity(region): class Watershed(InstanceMethod): """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=WATERSHED, function=binary_watershed, num_sliders=2, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -446,13 +453,13 @@ def run_method(self, image): class ConnectedComponents(InstanceMethod): """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" - def __init__(self, widget_parent = None): + def __init__(self, widget_parent=None): super().__init__( name=CONNECTED_COMP, function=binary_connected, num_sliders=1, num_counters=1, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.sliders[0].label.setText("Foreground probability threshold") @@ -483,7 +490,7 @@ def __init__(self, widget_parent): function=voronoi_otsu, num_sliders=0, num_counters=2, - widget_parent=widget_parent + widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness self.counters[ diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index b94c8d23..2de03dd8 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -643,7 +643,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) # TODO(cyril) check if correct + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 201452ac..e4a5af02 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -608,7 +608,9 @@ def start(self): self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[self.instance_widgets.method_choice.currentText()] + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -877,7 +879,9 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method_name = self.worker_config.post_process_config.instance.method.name + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index af5e8c3b..e1dd57ef 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -122,11 +122,13 @@ class Zoom: enabled: bool = True zoom_values: List[float] = None + @dataclass class InstanceSegConfig: enabled: bool = False method: InstanceMethod = None + @dataclass class PostProcessConfig: """Class to record params for post processing diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index cf8cfdda..1aa52932 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -10,11 +10,14 @@ PERCENT_CORRECT = 0.7 + @dataclass class LabelInfo: gt_index: int model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = 0.0 # ratio of pixels of the gt label correctly labelled + best_model_label_coverage: float = ( + 0.0 # ratio of pixels of the gt label correctly labelled + ) overall_gt_label_coverage: float = 0.0 # true positive ration of the model def get_correct_ratio(self): @@ -24,16 +27,25 @@ def get_correct_ratio(self): else: return None + def eval_model(gt_labels, model_labels, print_report=False): - report_list, new_labels, fused_labels = create_label_report(gt_labels, model_labels) + report_list, new_labels, fused_labels = create_label_report( + gt_labels, model_labels + ) per_label_perfs = [] for report in report_list: if print_report: - log.info(f"Label {report.gt_index} : {report.model_labels_id_and_status}") - log.info(f"Best model label coverage : {report.best_model_label_coverage}") - log.info(f"Overall gt label coverage : {report.overall_gt_label_coverage}") + log.info( + f"Label {report.gt_index} : {report.model_labels_id_and_status}" + ) + log.info( + f"Best model label coverage : {report.best_model_label_coverage}" + ) + log.info( + f"Overall gt label coverage : {report.overall_gt_label_coverage}" + ) perf = report.get_correct_ratio() if perf is not None: @@ -43,8 +55,6 @@ def eval_model(gt_labels, model_labels, print_report=False): return per_label_perfs.mean(), new_labels, fused_labels - - def create_label_report(gt_labels, model_labels): """Map the model's labels to the neurons labels. Parameters @@ -63,7 +73,6 @@ def create_label_report(gt_labels, model_labels): The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact """ - map_labels_existing = [] map_fused_neurons = {} "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" @@ -135,9 +144,7 @@ def create_label_report(gt_labels, model_labels): # log.debug(ratio) ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ( - ratio_model_lab_for_given_gt_lab.max() - ) + info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() best_model_lab_id = model_lab_on_gt[ np.argmax(ratio_model_lab_for_given_gt_lab) @@ -369,9 +376,7 @@ def evaluate_model_performance( "Mean true positive ratio of the model for fused neurons: ", ) log.info(mean_true_positive_ratio_model_fused) - log.info( - "Mean ratio of false pixel in artefacts: " - ) + log.info("Mean ratio of false pixel in artefacts: ") log.info(mean_ratio_false_pixel_artefact) if visualize: From 1f95455941d8458dd9fd3e8a27b37c9cb0b79fd6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 15:49:45 +0100 Subject: [PATCH 482/577] Complete instance method evaluation --- .../dev_scripts/evaluate_labels.py | 564 +++++++++--------- notebooks/assess_instance.ipynb | 290 ++++----- 2 files changed, 385 insertions(+), 469 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 1aa52932..3082e79f 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,275 +1,15 @@ import numpy as np -from collections import Counter -from dataclasses import dataclass import pandas as pd from tqdm import tqdm -from typing import Dict import napari from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.7 - - -@dataclass -class LabelInfo: - gt_index: int - model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) - best_model_label_coverage: float = ( - 0.0 # ratio of pixels of the gt label correctly labelled - ) - overall_gt_label_coverage: float = 0.0 # true positive ration of the model - - def get_correct_ratio(self): - for model_label, status in self.model_labels_id_and_status.items(): - if status == "correct": - return self.best_model_label_coverage - else: - return None - - -def eval_model(gt_labels, model_labels, print_report=False): - - report_list, new_labels, fused_labels = create_label_report( - gt_labels, model_labels - ) - - per_label_perfs = [] - for report in report_list: - if print_report: - log.info( - f"Label {report.gt_index} : {report.model_labels_id_and_status}" - ) - log.info( - f"Best model label coverage : {report.best_model_label_coverage}" - ) - log.info( - f"Overall gt label coverage : {report.overall_gt_label_coverage}" - ) - - perf = report.get_correct_ratio() - if perf is not None: - per_label_perfs.append(perf) - - per_label_perfs = np.array(per_label_perfs) - return per_label_perfs.mean(), new_labels, fused_labels - - -def create_label_report(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - - map_labels_existing = [] - map_fused_neurons = {} - "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" - background_labels = model_labels[np.where((gt_labels == 0))] - "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" - new_labels = [] - for lab in np.unique(background_labels): - if lab == 0: - continue - gt_background_size_at_lab = ( - gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] - .flatten() - .shape[0] - ) - gt_lab_size = ( - gt_labels[np.where(model_labels == lab)].flatten().shape[0] - ) - if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: - new_labels.append(lab) - - label_report_list = [] - # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label - # model_label_values = {} # contains the model labels value assigned to each unique gt label - not_found_id = 0 - - for i in tqdm(np.unique(gt_labels)): - if i == 0: - continue - - gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label - - model_lab_on_gt = model_labels[ - np.where(((gt_labels == i) & (model_labels != 0))) - ] # all models labels on single gt_label - info = LabelInfo(i) - - info.model_labels_id_and_status = { - label_id: "" for label_id in np.unique(model_lab_on_gt) - } - - if model_lab_on_gt.shape[0] == 0: - info.model_labels_id_and_status[ - f"not_found_{not_found_id}" - ] = "not found" - not_found_id += 1 - label_report_list.append(info) - continue - - log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") - - # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label - log.debug( - f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" - ) - - ratio = [] - for model_lab_id in info.model_labels_id_and_status.keys(): - size_model_label = ( - model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] - .flatten() - .shape[0] - ) - size_gt_label = gt_label.flatten().shape[0] - - log.debug(f"size_model_label : {size_model_label}") - log.debug(f"size_gt_label : {size_gt_label}") - - ratio.append(size_model_label / size_gt_label) - - # log.debug(ratio) - ratio_model_lab_for_given_gt_lab = np.array(ratio) - info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() - - best_model_lab_id = model_lab_on_gt[ - np.argmax(ratio_model_lab_for_given_gt_lab) - ] - log.debug(f"best_model_lab_id : {best_model_lab_id}") - - info.overall_gt_label_coverage = ( - ratio_model_lab_for_given_gt_lab.sum() - ) # the ratio of the pixels of the true label correctly labelled - - if info.best_model_label_coverage > PERCENT_CORRECT: - info.model_labels_id_and_status[best_model_lab_id] = "correct" - # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] - else: - info.model_labels_id_and_status[best_model_lab_id] = "wrong" - for model_lab_id in np.unique(model_lab_on_gt): - if model_lab_id != best_model_lab_id: - log.debug(model_lab_id, "is wrong") - info.model_labels_id_and_status[model_lab_id] = "wrong" - - label_report_list.append(info) - - correct_labels_id = [] - for report in label_report_list: - for i_lab in report.model_labels_id_and_status.keys(): - if report.model_labels_id_and_status[i_lab] == "correct": - correct_labels_id.append(i_lab) - """Find all labels in label_report_list that are correct more than once""" - duplicated_labels = [ - item for item, count in Counter(correct_labels_id).items() if count > 1 - ] - "Sum up the size of all duplicated labels" - for i in duplicated_labels: - for report in label_report_list: - if ( - i in report.model_labels_id_and_status.keys() - and report.model_labels_id_and_status[i] == "correct" - ): - size = ( - model_labels[np.where(model_labels == i)] - .flatten() - .shape[0] - ) - map_fused_neurons[i] = size - - return label_report_list, new_labels, map_fused_neurons - - -def map_labels(gt_labels, model_labels): - """Map the model's labels to the neurons labels. - Parameters - ---------- - gt_labels : ndarray - Label image with neurons labelled as mulitple values. - model_labels : ndarray - Label image from the model labelled as mulitple values. - Returns - ------- - map_labels_existing: numpy array - The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled - map_fused_neurons: numpy array - The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones - new_labels: list - The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact - """ - map_labels_existing = [] - map_fused_neurons = [] - new_labels = [] - - for i in tqdm(np.unique(model_labels)): - if i == 0: - continue - indexes = gt_labels[model_labels == i] - # find the most common labels in the label i of the model - unique, counts = np.unique(indexes, return_counts=True) - tmp_map = [] - total_pixel_found = 0 - - # log.debug(f"i: {i}") - for ii in range(len(unique)): - true_positive_ratio_model = counts[ii] / np.sum(counts) - # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found - # log.debug(f"unique: {unique[ii]}") - if unique[ii] == 0: - if true_positive_ratio_model > 0.5: - # -> artifact found - new_labels.append([i, true_positive_ratio_model]) - else: - # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, - # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found - ratio_pixel_found = counts[ii] / np.sum( - gt_labels == unique[ii] - ) - if ratio_pixel_found > 0.8: - total_pixel_found += np.sum(counts[ii]) - tmp_map.append( - [ - i, - unique[ii], - ratio_pixel_found, - true_positive_ratio_model, - ] - ) - - if len(tmp_map) == 1: - # map to only one true neuron -> found neuron - map_labels_existing.append(tmp_map[0]) - elif len(tmp_map) > 1: - # map to multiple true neurons -> fused neuron - for ii in range(len(tmp_map)): - if total_pixel_found > np.sum(counts): - raise ValueError( - f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" - ) - tmp_map[ii][3] = total_pixel_found / np.sum(counts) - map_fused_neurons += tmp_map - - # log.debug(f"map_labels_existing: {map_labels_existing}") - # log.debug(f"map_fused_neurons: {map_fused_neurons}") - # log.debug(f"new_labels: {new_labels}") - return map_labels_existing, map_fused_neurons, new_labels +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels, do_print=False, visualize=False + labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False ): """Evaluate the model performance. Parameters @@ -278,7 +18,7 @@ def evaluate_model_performance( Label image with neurons labelled as mulitple values. model_labels : ndarray Label image from the model labelled as mulitple values. - do_print : bool + print_details : bool If True, print the results. visualize : bool If True, visualize the results. @@ -305,7 +45,7 @@ def evaluate_model_performance( """ log.debug("Mapping labels...") map_labels_existing, map_fused_neurons, new_labels = map_labels( - labels, model_labels + labels, model_labels, threshold_correct ) # calculate the number of neurons individually found @@ -351,33 +91,30 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - if do_print: - log.info("Neurons found: ") - log.info(neurons_found) - log.info("Neurons fused: ") - log.info(neurons_fused) - log.info("Neurons not found: ") - log.info(neurons_not_found) - log.info("Artefacts found: ") - log.info(artefacts_found) + log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") + log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") + log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) log.info( - "Mean true positive ratio of the model: ", + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" ) - log.info(mean_true_positive_ratio_model) log.info( - "Mean ratio of the neurons pixels correctly labelled: ", + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" ) - log.info(mean_ratio_pixel_found) log.info( - "Mean ratio of the neurons pixels correctly labelled for fused neurons: ", + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" ) - log.info(mean_ratio_pixel_found_fused) log.info( - "Mean true positive ratio of the model for fused neurons: ", + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" ) - log.info(mean_true_positive_ratio_model_fused) - log.info("Mean ratio of false pixel in artefacts: ") - log.info(mean_ratio_false_pixel_artefact) if visualize: viewer = napari.Viewer() @@ -436,6 +173,81 @@ def evaluate_model_performance( ) +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + def save_as_csv(results, path): """ Save the results as a csv file @@ -464,6 +276,192 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. +# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if info.best_model_label_coverage > PERCENT_CORRECT: +# info.model_labels_id_and_status[best_model_lab_id] = "correct" +# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] +# else: +# info.model_labels_id_and_status[best_model_lab_id] = "wrong" +# for model_lab_id in np.unique(model_lab_on_gt): +# if model_lab_id != best_model_lab_id: +# log.debug(model_lab_id, "is wrong") +# info.model_labels_id_and_status[model_lab_id] = "wrong" +# +# label_report_list.append(info) +# +# correct_labels_id = [] +# for report in label_report_list: +# for i_lab in report.model_labels_id_and_status.keys(): +# if report.model_labels_id_and_status[i_lab] == "correct": +# correct_labels_id.append(i_lab) +# """Find all labels in label_report_list that are correct more than once""" +# duplicated_labels = [ +# item for item, count in Counter(correct_labels_id).items() if count > 1 +# ] +# "Sum up the size of all duplicated labels" +# for i in duplicated_labels: +# for report in label_report_list: +# if ( +# i in report.model_labels_id_and_status.keys() +# and report.model_labels_id_and_status[i] == "correct" +# ): +# size = ( +# model_labels[np.where(model_labels == i)] +# .flatten() +# .shape[0] +# ) +# map_fused_neurons[i] = size +# +# return label_report_list, new_labels, map_fused_neurons # if __name__ == "__main__": # """ diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index d521c395..4bf89452 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -4,9 +4,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -22,6 +19,7 @@ " binary_connected,\n", " binary_watershed,\n", " voronoi_otsu,\n", + " to_semantic,\n", ")" ] }, @@ -29,9 +27,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "pycharm": { - "is_executing": true - }, "tags": [] }, "outputs": [], @@ -50,12 +45,14 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(25, 64, 64)\n", - "(25, 64, 64)\n" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -72,9 +69,7 @@ "\n", "\n", "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", - "viewer.add_labels(gt_labels_resized, name=\"gt\")\n", - "print(prediction_resized.shape)\n", - "print(gt_labels_resized.shape)" + "viewer.add_labels(gt_labels_resized, name=\"gt\")" ] }, { @@ -84,9 +79,33 @@ "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false } }, "outputs": [], @@ -98,7 +117,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -110,48 +143,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,112 - Mapping labels...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [00:00<00:00, 3913.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-03-22 14:47:30,150 - Calculating the number of neurons not found...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "2\n" ] - }, - { - "data": { - "text/plain": [ - "(124, 0, 0, 0, 1.0, 1.0, nan, nan, nan)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -162,23 +168,22 @@ { "data": { "text/plain": [ - "dtype('int32')" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "connected = binary_connected(prediction_resized)\n", - "viewer.add_labels(connected, name=\"connected\")\n", - "connected.dtype" + "connected = binary_connected(prediction_resized,thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -190,21 +195,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,231 - Mapping labels...\n" + "2023-03-22 15:48:05,891 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:00<00:00, 3069.93it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,265 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -217,18 +225,10 @@ { "data": { "text/plain": [ - "(45,\n", - " 38,\n", - " 41,\n", - " 8,\n", - " 0.8424215218790255,\n", - " 0.9295584641244191,\n", - " 0.9332428837640889,\n", - " 0.8300682812799907,\n", - " 1.0)" + "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -251,21 +251,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,344 - Mapping labels...\n" + "2023-03-22 15:48:05,995 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 2864.94it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,376 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -278,25 +281,17 @@ { "data": { "text/plain": [ - "(47,\n", - " 37,\n", - " 40,\n", - " 0,\n", - " 0.8426909426266451,\n", - " 0.9355733839977192,\n", - " 0.9369165405470844,\n", - " 0.8198206611354032,\n", - " nan)" + "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "watershed = binary_watershed(\n", - " prediction_resized, thres_small=20, rem_seed_thres=5\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", ")\n", "viewer.add_labels(watershed)\n", "eval.evaluate_model_performance(gt_labels_resized, watershed)" @@ -304,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { @@ -318,24 +313,24 @@ "(25, 64, 64)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=0.5)\n", + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", - "voronoi = remove_small_objects(voronoi, 10)\n", + "voronoi = remove_small_objects(voronoi, 2)\n", "viewer.add_labels(voronoi)\n", "voronoi.shape" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { @@ -349,7 +344,7 @@ "dtype('int64')" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -360,104 +355,35 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 2, 3, 7, 10, 11, 12, 14, 15, 16, 17, 18, 19,\n", - " 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48,\n", - " 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,\n", - " 77, 79, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 92,\n", - " 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107,\n", - " 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,\n", - " 122], dtype=uint32),\n", - " array([97911, 46, 18, 10, 47, 57, 44, 54, 22,\n", - " 94, 36, 42, 51, 41, 35, 42, 46, 47,\n", - " 52, 31, 12, 45, 68, 106, 69, 56, 32,\n", - " 69, 46, 75, 78, 41, 42, 42, 50, 50,\n", - " 48, 50, 44, 49, 50, 54, 58, 43, 41,\n", - " 39, 49, 15, 33, 25, 44, 52, 32, 81,\n", - " 29, 46, 42, 46, 34, 30, 34, 36, 57,\n", - " 21, 26, 51, 40, 49, 34, 46, 45, 13,\n", - " 28, 37, 44, 31, 46, 47, 42, 40, 42,\n", - " 47, 116, 44, 32, 33, 28, 27, 41, 35,\n", - " 37, 41, 38, 33, 50, 25, 33, 80, 19,\n", - " 28, 36, 28, 14, 31, 54], dtype=int64))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(voronoi, return_counts=True)" + "# np.unique(voronoi, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n", - " 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,\n", - " 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,\n", - " 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,\n", - " 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,\n", - " 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,\n", - " 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,\n", - " 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,\n", - " 118, 119, 120, 121, 122, 123, 124, 125], dtype=int64),\n", - " array([97748, 4, 4, 4, 91, 4, 29, 54, 25,\n", - " 59, 26, 18, 4, 37, 39, 21, 4, 38,\n", - " 33, 47, 67, 31, 16, 4, 61, 92, 50,\n", - " 43, 29, 26, 38, 22, 39, 28, 32, 43,\n", - " 32, 60, 48, 16, 36, 77, 64, 41, 41,\n", - " 54, 29, 22, 46, 47, 26, 17, 31, 76,\n", - " 28, 58, 40, 57, 35, 19, 41, 47, 48,\n", - " 60, 44, 46, 33, 46, 53, 54, 71, 8,\n", - " 26, 45, 20, 55, 26, 43, 62, 55, 54,\n", - " 43, 51, 33, 43, 45, 36, 22, 22, 52,\n", - " 24, 44, 36, 60, 42, 31, 59, 8, 34,\n", - " 43, 57, 54, 46, 69, 42, 47, 25, 18,\n", - " 41, 41, 56, 4, 48, 4, 66, 14, 25,\n", - " 33, 25, 7, 5, 7, 19, 32, 40],\n", - " dtype=int64))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "np.unique(gt_labels_resized, return_counts=True)" + "# np.unique(gt_labels_resized, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { @@ -469,21 +395,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,755 - Mapping labels...\n" + "2023-03-22 15:48:06,360 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [00:00<00:00, 3290.07it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 14:47:30,792 - Calculating the number of neurons not found...\n" + "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", + "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" ] }, { @@ -496,18 +425,10 @@ { "data": { "text/plain": [ - "(72,\n", - " 8,\n", - " 44,\n", - " 1,\n", - " 0.8348479609766444,\n", - " 0.9314226186350036,\n", - " 0.9483750072126669,\n", - " 0.8528417100412058,\n", - " 1.0)" + "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -518,14 +439,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false - }, - "pycharm": { - "is_executing": true } }, "outputs": [], From 54519ac87b2660c47095ff810773a94601703aea Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:39:55 +0100 Subject: [PATCH 483/577] Added pre-commit hooks --- .pre-commit-config.yaml | 44 +++++++++++++++++++++++++++++------------ requirements.txt | 2 ++ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..802dfe20 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,44 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v4.0.1 +# hooks: # - id: check-docstring-first - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - id: end-of-file-fixer +# - id: trailing-whitespace +# - repo: https://github.com/asottile/setup-cfg-fmt +# rev: v1.20.0 +# hooks: +# - id: setup-cfg-fmt +# - repo: https://github.com/PyCQA/flake8 +# rev: 4.0.1 +# hooks: +# - id: flake8 +# additional_dependencies: [flake8-typing-imports>=1.9.0] +# - repo: https://github.com/myint/autoflake +# rev: v1.4 +# hooks: +# - id: autoflake +# args: ["--in-place", "--remove-all-unused-imports"] +# - repo: https://github.com/PyCQA/isort +# rev: 5.10.1 +# hooks: +# - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.262' + rev: 'v0.0.257' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=79] + args: [--line-length=88] +# - repo: https://github.com/asottile/pyupgrade +# rev: v2.29.1 +# hooks: +# - id: pyupgrade +# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/requirements.txt b/requirements.txt index 92aae176..a7dd1570 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,9 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 ruff From 9887cc71043ac5445ef281220bca719691f8e9c1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:40:31 +0100 Subject: [PATCH 484/577] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 802dfe20..d1e22fb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: rev: 22.3.0 hooks: - id: black - args: [--line-length=88] + args: [--line-length=79] # - repo: https://github.com/asottile/pyupgrade # rev: v2.29.1 # hooks: From 2b0bb6d1d67d2c75c8672cd540d2d0e226c9e7af Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:48:32 +0100 Subject: [PATCH 485/577] Update pyproject.toml --- pyproject.toml | 52 ++------------------------------------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c9a9d942..0003b804 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,57 +34,9 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" -[tool.setuptools] -include-package-data = true - -[tool.setuptools.packages.find] -where = ["."] - -[tool.setuptools.package-data] -"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] - [tool.ruff] -select = [ - "E", "F", "W", - "A", - "B", - "G", - "I", - "PT", - "PTH", - "RET", - "SIM", - "TCH", - "NPY", -] -# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -# and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004", "A003"] -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", - "docs/conf.py", - "napari_cellseg3d/_tests/conftest.py", -] +# Never enforce `E501` (line length violations). +ignore = ["E501"] [tool.black] line-length = 79 From c978043ad739245be686ef31d57532015be25929 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:50:33 +0100 Subject: [PATCH 486/577] Update pyproject.toml Ruff config --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0003b804..d80555c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Never enforce `E501` (line length violations). -ignore = ["E501"] +ignore = ["E501", "E741"] [tool.black] line-length = 79 From 2bf949417f0921bfbe087abf0e44acf6eb76658d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 22 Mar 2023 16:52:48 +0100 Subject: [PATCH 487/577] Enfore pre-commit style --- .gitignore | 6 -- .../_tests/test_plugin_inference.py | 2 - napari_cellseg3d/_tests/test_utils.py | 15 +--- .../code_models/instance_segmentation.py | 7 +- .../code_models/models/model_TRAILMAP.py | 28 ++++++- .../code_models/models/model_TRAILMAP_MS.py | 21 +---- napari_cellseg3d/code_models/workers.py | 64 ++------------- .../code_plugins/plugin_convert.py | 2 +- .../code_plugins/plugin_model_inference.py | 56 ++++++------- napari_cellseg3d/config.py | 9 +-- .../dev_scripts/artefact_labeling.py | 1 - .../dev_scripts/correct_labels.py | 1 - .../dev_scripts/evaluate_labels.py | 23 ++++-- napari_cellseg3d/dev_scripts/thread_test.py | 2 +- napari_cellseg3d/interface.py | 2 +- notebooks/assess_instance.ipynb | 79 +++++++++++++------ 16 files changed, 140 insertions(+), 178 deletions(-) diff --git a/.gitignore b/.gitignore index df67a187..427603f1 100644 --- a/.gitignore +++ b/.gitignore @@ -105,9 +105,3 @@ notebooks/full_plot.html *.csv *.png notebooks/instance_test.ipynb -*.prof - -#include test data -!napari_cellseg3d/_tests/res/test.tif -!napari_cellseg3d/_tests/res/test.png -!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index a17120c3..fbeb9943 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -8,8 +8,6 @@ from napari_cellseg3d.config import MODEL_LIST -def test_inference(make_napari_viewer, qtbot): - def test_inference(make_napari_viewer, qtbot): im_path = str(Path(__file__).resolve().parent / "res/test.tif") image = imread(im_path) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index dc680b35..12720688 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -9,7 +9,7 @@ def test_fill_list_in_between(): - test_list = [1, 2, 3, 4, 5, 6] + list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -37,7 +37,6 @@ def test_fill_list_in_between(): assert fill(test_list) == res - def test_align_array_sizes(): im = np.zeros((128, 512, 256)) print(im.shape) @@ -110,16 +109,8 @@ def test_normalize_x(): def test_parse_default_path(): - user_path = Path().home() - assert utils.parse_default_path([None]) == str(user_path) - - test_path = "C:/test/test" - path = [test_path, None, None] - assert utils.parse_default_path(path) == test_path - - long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime" - path = [test_path, None, None, long_path, ""] - assert utils.parse_default_path(path) == long_path + user_path = os.path.expanduser("~") + assert utils.parse_default_path([None]) == user_path def test_thread_test(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index f8fc2517..d5f10584 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -5,12 +5,10 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.filters import thresholding from skimage.measure import label from skimage.measure import regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed -from skimage.transform import resize # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes @@ -574,14 +572,13 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError as e: - logger.debug(f"Caught runtime error, most likely during testing") + except RuntimeError: + logger.debug("Caught runtime error, most likely during testing") self.setLayout(group.layout) self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index e6bbad55..0d9ebace 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,6 +2,26 @@ from torch import nn +def get_weights_file(): + # model additionally trained on Mathis/Wyss mesoSPIM data + return "TRAILMAP_PyTorch.pth" + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + + +def get_net(): + return TRAILMAP(1, 1) + + +def get_output(model, input): + out = model(input) + + return out + + +def get_validation(model, val_inputs): + return model(val_inputs) + + class TRAILMAP(nn.Module): def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() @@ -44,7 +64,7 @@ def forward(self, x): # print(out.shape) def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - return nn.Sequential( + encode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -57,7 +77,7 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - return nn.Sequential( + encode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -69,7 +89,7 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - return nn.Sequential( + decode = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -84,7 +104,7 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - return nn.Sequential( + out = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index baf8635d..73f842b1 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -8,23 +8,8 @@ class TRAILMAP_MS_(UNet3D): use_default_training = True weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" - # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) + return out - def __init__(self, in_channels=1, out_channels=1, **kwargs): - try: - super().__init__( - in_channels=in_channels, out_channels=out_channels, **kwargs - ) - except TypeError as e: - logger.warning(f"Caught TypeError: {e}") - super().__init__( - in_channels=in_channels, out_channels=out_channels - ) - # def get_output(self, input): - # out = self(input) - - # return out - # - # def get_validation(self, val_inputs): - # return self(val_inputs) +def get_validation(model, val_inputs): + return model(val_inputs) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 2de03dd8..4462db41 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -315,21 +315,6 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) - def raise_error(self, exception, msg): - """Raises an error in main thread""" - logger.error(msg, exc_info=True) - logger.error(exception, exc_info=True) - - self.log_signal.emit("!" * 20) - self.log_signal.emit("Error occured") - # self.log_signal.emit(msg) - # self.log_signal.emit(str(exception)) - - self.error_signal.emit(exception, msg) - self.errored.emit(exception) - yield exception - # self.quit() - def log_parameters(self): config = self.config @@ -505,11 +490,10 @@ def model_output( ): inputs = inputs.to("cpu") - model_output = lambda inputs: post_process_transforms( - self.config.model_info.get_model().get_output( - model, inputs - ) # TODO(cyril) refactor those functions - ) + # def model_output(inputs): + # return post_process_transforms( + # self.config.model_info.get_model().get_output(model, inputs) + # ) def model_output(inputs): return post_process_transforms( @@ -698,11 +682,10 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - return image + else: + return image - def instance_seg( - self, semantic_labels, image_id=0, original_filename="layer" - ): + def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") @@ -1095,14 +1078,6 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) - def raise_error(self, exception, msg): - """Sends an error to main thread""" - logger.error(msg, exc_info=True) - logger.error(exception, exc_info=True) - self.error_signal.emit(exception, msg) - self.errored.emit(exception) - self.quit() - def log_parameters(self): self.log("-" * 20) self.log("Parameters summary :\n") @@ -1314,31 +1289,6 @@ def train(self): ) # self.log("Loading dataset...\n") - def get_loader_func(num_samples): - return Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), - ), - EnsureTyped(keys=["image", "label"]), - ] - ) - if do_sampling: # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 2dc8f07c..0d56372a 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -139,7 +139,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + self.results_path.mkdir(exist_ok=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index e4a5af02..75336f26 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -15,9 +15,6 @@ InstanceWidgets, ) from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( - INSTANCE_SEGMENTATION_METHOD_LIST, -) from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets from napari_cellseg3d.code_models.model_workers import InferenceResult @@ -834,27 +831,23 @@ def on_yield(self, result: InferenceResult): ): zoom = self.worker_config.post_process_config.zoom.zoom_values - viewer.dims.ndisplay = 3 - viewer.scale_bar.visible = True - - if self.config.show_original and result.original is not None: - viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, - ) + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): + zoom = self.worker_config.post_process_config.zoom.zoom_values out_colormap = "twilight" if self.worker_config.post_process_config.thresholding.enabled: out_colormap = "turbo" + if self.config.show_original and result.original is not None: viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, ) if result.crf_results is not None: logger.debug( @@ -873,9 +866,16 @@ def on_yield(self, result: InferenceResult): self.worker_config.post_process_config.instance.method.name ) - number_cells = ( - np.unique(result.instance_labels.flatten()).size - 1 - ) # remove background + out_colormap = "twilight" + if self.worker_config.post_process_config.thresholding.enabled: + out_colormap = "turbo" + + viewer.add_image( + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, + ) if result.instance_labels is not None: labels = result.instance_labels @@ -889,22 +889,16 @@ def on_yield(self, result: InferenceResult): name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + viewer.add_labels(labels, name=name) if result.stats is not None and isinstance( result.stats, list ): log.debug(f"len stats : {len(result.stats)}") - for i, stats in enumerate(result.stats): - # stats = result.stats - - if ( - self.worker_config.compute_stats - and stats is not None - ): - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if self.worker_config.compute_stats and stats is not None: + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) self.log.print_and_log( f"Number of instances in channel {i} : {stats.number_objects[0]}" diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index e1dd57ef..89d88c52 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,10 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ConnectedComponents from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import VoronoiOtsu -from napari_cellseg3d.code_models.model_instance_seg import Watershed # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet @@ -94,9 +91,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info("Model list :") - for model_name in MODEL_LIST: - logger.info(f" * {model_name}") + logger.info( + "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) + ) return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b66ace64..9a344545 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -417,7 +417,6 @@ def create_artefact_labels_from_folder( if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] print(f"REPO PATH : {repo_path}") paths = [ diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index da938c01..cd09754e 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -335,7 +335,6 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") image_path = str(im_path / "image.tif") gt_labels_path = str(im_path / "labels.tif") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index 3082e79f..a972fa69 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -5,11 +5,15 @@ from napari_cellseg3d.utils import LOGGER as log -PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct def evaluate_model_performance( - labels, model_labels,threshold_correct = PERCENT_CORRECT , print_details=False, visualize=False + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, ): """Evaluate the model performance. Parameters @@ -91,9 +95,15 @@ def evaluate_model_performance( else: mean_ratio_false_pixel_artefact = np.nan - log.info(f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%") - log.info(f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%") - log.info(f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%") + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) if print_details: log.info(f"Neurons found: {neurons_found}") @@ -131,7 +141,7 @@ def evaluate_model_performance( ) viewer.add_labels(found_label, name="ground truth found") neurones_not_found_labels = np.where( - np.isin(unique_labels, neurons_found_labels) == False, + np.isin(unique_labels, neurons_found_labels) is False, unique_labels, 0, ) @@ -276,6 +286,7 @@ def save_as_csv(results, path): ) df.to_csv(path, index=False) + ####################### # Slower version that was used for debugging ####################### diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index a48f6db0..dd3ff4e5 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -131,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(rand_gen.random(512, 512)) + viewer = napari.view_image(np.random.rand(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index c64cca19..50f19269 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1077,7 +1077,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 4bf89452..b8810301 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -96,7 +96,10 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", - "dice_coeff(to_semantic(gt_labels_resized.copy()), to_semantic(prediction_resized.copy()))" + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" ] }, { @@ -145,7 +148,7 @@ "text": [ "(25, 64, 64)\n", "(25, 64, 64)\n", - "2\n" + "125\n" ] } ], @@ -168,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -177,7 +180,7 @@ } ], "source": [ - "connected = binary_connected(prediction_resized,thres_small=2)\n", + "connected = binary_connected(prediction_resized, thres_small=2)\n", "viewer.add_labels(connected, name=\"connected\")" ] }, @@ -195,24 +198,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,891 - Mapping labels...\n" + "2023-03-22 15:48:47,057 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4352.70it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,919 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:05,921 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:05,922 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -225,7 +228,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 12, nan, nan, nan, nan, 1.0)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -251,24 +262,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:05,995 - Mapping labels...\n" + "2023-03-22 15:48:47,168 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4354.19it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,021 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,023 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,023 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,024 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -281,7 +292,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 10, nan, nan, nan, nan, 1.0)" + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -395,24 +414,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,360 - Mapping labels...\n" + "2023-03-22 15:48:47,570 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 4153.73it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:06,392 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:06,394 - Percent of non-fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Percent of fused neurons found: 0.00%\n", - "2023-03-22 15:48:06,395 - Overall percent of neurons found: 0.00%\n" + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -425,7 +444,15 @@ { "data": { "text/plain": [ - "(0, 0, 1, 17, nan, nan, nan, nan, 0.7306099091287442)" + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" ] }, "execution_count": 15, From 51a6c3579144f46ec889dd6d8b08d4298f695fed Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:30:55 +0200 Subject: [PATCH 488/577] Update .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 427603f1..ee1bf4a0 100644 --- a/.gitignore +++ b/.gitignore @@ -104,4 +104,4 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png -notebooks/instance_test.ipynb + From c2e90f7637f219faeb5403801fbf37b7ecd1bc40 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:32:56 +0200 Subject: [PATCH 489/577] Version bump --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..2c537225 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc6" +__version__ = "0.0.2rc2" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 552f70ea..7149d8cc 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -39,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.3rc1'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 9045eec4..3d926ea4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.3rc1 +version = 0.0.2rc2 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 235421128758ef7bb12328f8e8e3b0d2e99874ad Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 11 Apr 2023 11:33:40 +0200 Subject: [PATCH 490/577] Revert "Version bump" This reverts commit 6e39971b39fb926084f3ed71d82e8c25f68f8b6f. --- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 2c537225..6e2681e8 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc2" +__version__ = "0.0.2rc1" diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 7149d8cc..54c34a8f 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -39,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc2'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/setup.cfg b/setup.cfg index 3d926ea4..c9826f06 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc2 +version = 0.0.2rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 22868a5cedc25ee4719ce3eddc2be020e6f25072 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Apr 2023 09:43:27 +0200 Subject: [PATCH 491/577] Updated project files --- pyproject.toml | 28 +++++++++++----------------- setup.cfg | 2 +- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d80555c2..5dec250c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "napari_cellseg3d" -version = "0.0.3rc1" +version = "0.0.2rc6" authors = [ {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, @@ -34,6 +34,15 @@ dependencies = [ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] + +[tool.setuptools.package-data] +"*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] + [tool.ruff] # Never enforce `E501` (line length violations). ignore = ["E501", "E741"] @@ -46,15 +55,10 @@ profile = "black" line_length = 79 [project.optional-dependencies] -crf = [ - "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", -] dev = [ "isort", "black", "ruff", - "pre-commit", - ] docs = [ "sphinx", @@ -65,17 +69,7 @@ docs = [ test = [ "pytest", "pytest_qt", - "pytest-cov", "coverage", "tox", "twine", - "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", -] -onnx-cpu = [ - "onnx", - "onnxruntime" -] -onnx-gpu = [ - "onnx", - "onnxruntime-gpu" -] +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index c9826f06..2420dd1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc1 +version = 0.0.2rc6 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 0e403b817027b54a7d6dc93de5195816c0659dfa Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 09:45:17 +0200 Subject: [PATCH 492/577] Fixed missing parent error --- napari_cellseg3d/code_models/instance_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index d5f10584..c10da6ec 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -482,7 +482,7 @@ def run_method(self, image): class VoronoiOtsu(InstanceMethod): """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameter, see voronoi_otsu""" - def __init__(self, widget_parent): + def __init__(self, widget_parent=None): super().__init__( name=VORONOI_OTSU, function=voronoi_otsu, From b6fc3f6d6c389c690aa74f081bd63e35ac899583 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 15 Apr 2023 10:40:19 +0200 Subject: [PATCH 493/577] Fixed wrong value in instance sliders --- napari_cellseg3d/code_models/instance_segmentation.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index c10da6ec..c1161df1 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -163,6 +163,9 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) instance = cle.voronoi_otsu_labeling( volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma ) @@ -562,7 +565,7 @@ def _build(self): method_class = method(widget_parent=self.parent()) self.methods[name] = method_class self.instance_widgets[name] = [] - # moderately unsafe way to init those widgets + # moderately unsafe way to init those widgets ? if len(method_class.sliders) > 0: for slider in method_class.sliders: group.layout.addWidget(slider.container) @@ -572,8 +575,10 @@ def _build(self): group.layout.addWidget(counter.label) group.layout.addWidget(counter) self.instance_widgets[name].append(counter) - except RuntimeError: - logger.debug("Caught runtime error, most likely during testing") + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) self.setLayout(group.layout) self._set_visibility() From 11e0dccb5b1ad2082cb66c4ae8b2dad2a3c8498b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 10:44:04 +0200 Subject: [PATCH 494/577] Removing dask-image --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ee1bf4a0..0ec12b01 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ notebooks/full_plot.html *.csv *.png +*.prof From 5db5b785bfc6a41b2149ef717ee0e5d4e7b01441 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 17:20:52 +0200 Subject: [PATCH 495/577] Fixed erroneous dtype conversion --- napari_cellseg3d/code_models/instance_segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index c1161df1..6914a9e2 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -162,7 +162,7 @@ def voronoi_otsu( """ # remove_small_size (float): remove all objects smaller than the specified size in pixels - semantic = np.squeeze(volume) + # semantic = np.squeeze(volume) logger.debug( f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" ) @@ -515,6 +515,7 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): + ################ # For debugging # import napari From e80a655740dd13dde4f335dd3fe690f0b35119aa Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:28:30 +0200 Subject: [PATCH 496/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 5d5ada20..b513e1be 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS @@ -25,9 +24,4 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) - if utils_name == "Convert to instance labels": - # to avoid issues with Voronoi-Otsu missing runtime - menu = widget.utils_widgets[i].instance_widgets.method_choice - menu.setCurrentIndex(menu.currentIndex() + 1) - widget.utils_widgets[i]._start() From deb459260a0fc328726c035410b9a73a6a02fefb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:50:16 +0200 Subject: [PATCH 497/577] Update plugin_convert.py --- napari_cellseg3d/code_plugins/plugin_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 0d56372a..2dc8f07c 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -139,7 +139,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): From d6c359ba27b9ba3e9b34e9d22ed1d210b4222751 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:02:47 +0200 Subject: [PATCH 498/577] Update tox.ini Added pocl for testing on GH Actions --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4b04a5bc..8aaea25e 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy -; pyopencl[pocl] + pocl ; opencv-python extras = crf usedevelop = true From c0cbbbe92456edebb585dc1220d3bfc82af87f68 Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:07:58 +0200 Subject: [PATCH 499/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 8aaea25e..22b09bf5 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl + pocl-binary-distribution ; opencv-python extras = crf usedevelop = true From 80f6ea3bf33cc07535ef00abac7f3242d79ad3d2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:18:52 +0200 Subject: [PATCH 500/577] Found existing pocl --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 22b09bf5..82fa219b 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pocl-binary-distribution + pyopencl[pocl] ; opencv-python extras = crf usedevelop = true From 5e5f63c385f98f7018d721d2bbcd5943ae6960a2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 11:41:23 +0200 Subject: [PATCH 501/577] Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime --- napari_cellseg3d/_tests/test_plugin_utils.py | 5 +++++ tox.ini | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index b513e1be..7ca0555f 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -24,4 +24,9 @@ def test_utils_plugin(make_napari_viewer): assert isinstance( widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] ) + if utils_name == "Convert to instance labels": + # to avoid issues with Voronoi-Otsu missing runtime + menu = widget.utils_widgets[i].instance_widgets.method_choice + menu.setCurrentIndex(menu.currentIndex() + 1) + widget.utils_widgets[i]._start() diff --git a/tox.ini b/tox.ini index 82fa219b..4b04a5bc 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pyopencl[pocl] +; pyopencl[pocl] ; opencv-python extras = crf usedevelop = true From 94384db167cebafc84cb2128e74df1c67fd426c3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 13:40:19 +0200 Subject: [PATCH 502/577] Relabeling tests --- .gitignore | 6 +- .../_tests/test_labels_correction.py | 3 +- .../dev_scripts/artefact_labeling.py | 93 +++++++++---------- .../dev_scripts/correct_labels.py | 75 ++++++++++----- 4 files changed, 102 insertions(+), 75 deletions(-) diff --git a/.gitignore b/.gitignore index 0ec12b01..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -104,5 +104,9 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png - *.prof + +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..9d4e7801 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,7 +1,6 @@ from pathlib import Path - -import numpy as np from tifffile import imread +import numpy as np from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 9a344545..bf724a46 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,7 +1,5 @@ import numpy as np -from tifffile import imread -from tifffile import imwrite -from pathlib import Path +from tifffile import imwrite, imread import scipy.ndimage as ndimage import os import napari @@ -64,7 +62,7 @@ def map_labels(labels, artefacts): def make_labels( - path_image, + image, path_labels_out, threshold_factor=1, threshold_size=30, @@ -76,7 +74,7 @@ def make_labels( """Detect nucleus. using a binary watershed algorithm and otsu thresholding. Parameters ---------- - path_image : str + image : str Path to image. path_labels_out : str Path of the output labelled image. @@ -96,7 +94,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(path_image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -126,28 +124,26 @@ def make_labels( ) -def select_image_by_labels( - path_image, path_labels, path_image_out, label_values -): +def select_image_by_labels(image, labels, path_image_out, label_values): """Select image by labels. Parameters ---------- - path_image : str - Path to image. - path_labels : str - Path to labels. + image : np.array + image. + labels : np.array + labels. path_image_out : str Path of the output image. label_values : list List of label values to select. """ - image = imread(path_image) - labels = imread(path_labels) + # image = imread(image) + # labels = imread(labels) image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) -# select the smalles cube that contains all the none zero pixel of an 3d image +# select the smallest cube that contains all the non-zero pixels of a 3d image def get_bounding_box(img): height = np.any(img, axis=(0, 1)) rows = np.any(img, axis=(0, 2)) @@ -165,16 +161,15 @@ def crop_image(img): return img[xmin:xmax, ymin:ymax, zmin:zmax] -def crop_image_path(path_image, path_image_out): +def crop_image_path(image, path_image_out): """Crop image. Parameters ---------- - path_image : str - Path to image. + image : np.array + image path_image_out : str Path of the output image. """ - image = imread(path_image) image = crop_image(image) imwrite(path_image_out, image.astype(np.float32)) @@ -307,8 +302,8 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): def create_artefact_labels( - image_path, - labels_path, + image, + labels, output_path, threshold_artefact_brightness_percent=40, threshold_artefact_size_percent=1, @@ -317,10 +312,10 @@ def create_artefact_labels( """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. Parameters ---------- - image_path : str - Path to image file. - labels_path : str - Path to label image file with each neurons labelled as a different value. + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. output_path : str Path to save the output label image file. threshold_artefact_brightness_percent : int, optional @@ -330,9 +325,6 @@ def create_artefact_labels( contrast_power : int, optional Power for contrast enhancement. """ - image = imread(image_path) - labels = imread(labels_path) - artefacts = make_artefact_labels( image, labels, @@ -352,11 +344,12 @@ def visualize_images(paths): Parameters ---------- paths : list - List of paths to images to visualize. + List of images to visualize. """ viewer = napari.Viewer(ndisplay=3) for path in paths: - viewer.add_image(imread(path), name=os.path.basename(path)) + image = imread(path) + viewer.add_image(image) # wait for the user to close the viewer napari.run() @@ -416,22 +409,22 @@ def create_artefact_labels_from_folder( ) -if __name__ == "__main__": - repo_path = Path(__file__).resolve().parents[1] - print(f"REPO PATH : {repo_path}") - paths = [ - "dataset_clean/cropped_visual/train", - "dataset_clean/cropped_visual/val", - "dataset_clean/somatomotor", - "dataset_clean/visual_tif", - ] - for data_path in paths: - path = str(repo_path / data_path) - print(path) - create_artefact_labels_from_folder( - path, - do_visualize=False, - threshold_artefact_brightness_percent=20, - threshold_artefact_size_percent=1, - contrast_power=20, - ) +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index cd09754e..50f2e47a 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -4,6 +4,7 @@ import scipy.ndimage as ndimage import napari from pathlib import Path +from functools import partial import time import warnings from napari.qt.threading import thread_worker @@ -85,13 +86,16 @@ def add_label(old_label, artefact, new_label_path, i_labels_to_add): returns = [] -def ask_labels(unique_artefact): +def ask_labels(unique_artefact, test=False): global returns returns = [] - i_labels_to_add_tmp = input( - "Which labels do you want to add (0 to skip) ? (separated by a comma):" - ) - i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] if i_labels_to_add_tmp == [0]: print("no label added") @@ -135,7 +139,13 @@ def ask_labels(unique_artefact): def relabel( - image_path, label_path, go_fast=False, check_for_unicity=True, delay=0.3 + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, ): """relabel the image labelled with different label for each neuron and save it in the save_path location Parameters @@ -150,6 +160,8 @@ def relabel( if True, the relabeling will check if the labels are unique, by default True delay : float, optional the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None """ global returns @@ -164,9 +176,10 @@ def relabel( print( "visualize the relabeld image in white the previous labels and in red the new labels" ) - visualize_map( - map_labels_existing, label_path, new_label_path, delay=delay - ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) label_path = new_label_path # detect artefact print("detection of potential neurons (in progress)") @@ -186,15 +199,22 @@ def relabel( unique_artefact = list(np.unique(artefact)) while loop: # visualize the artefact and ask the user which label to add to the label image - t = threading.Thread(target=ask_labels, args=(unique_artefact,)) + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) t.start() artefact_copy = np.where( np.isin(artefact, i_labels_to_add), 0, artefact ) - viewer = napari.view_image(image) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") viewer.add_labels(artefact_copy, name="potential neurons") viewer.add_labels(imread(label_path), name="labels") - napari.run() + if not test: + napari.run() t.join() i_labels_to_add_tmp = returns[0] # check if the selected labels are neurones @@ -205,15 +225,26 @@ def relabel( np.isin(artefact, i_labels_to_add_tmp), artefact, 0 ) print("these labels will be added") - viewer = napari.view_image(image) - viewer.add_labels(artefact_copy, name="labels added") - napari.run() - revert = input("Do you want to revert? (y/n)") + if test: + viewer.close() + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() if revert != "y": i_labels_to_add = i_labels_to_add_tmp for i in i_labels_to_add: if i in unique_artefact: unique_artefact.remove(i) + if test: + break loop = input("Do you want to add more labels? (y/n)") == "y" # add the label to the label image new_label_path = initial_label_path[:-4] + "_new_label.tif" @@ -334,9 +365,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -if __name__ == "__main__": - im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") - image_path = str(im_path / "image.tif") - gt_labels_path = str(im_path / "labels.tif") - - relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) From 757c8b03fa310d02cf0a238102c417f625a2a80b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:06:43 +0200 Subject: [PATCH 503/577] Added new pre-commit hooks --- .pre-commit-config.yaml | 43 ++++++++++++----------------------------- pyproject.toml | 3 ++- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1e22fb1..da16a3b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,44 +1,25 @@ repos: -# - repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v4.0.1 -# hooks: -# - id: check-docstring-first -# - id: end-of-file-fixer -# - id: trailing-whitespace -# - repo: https://github.com/asottile/setup-cfg-fmt -# rev: v1.20.0 -# hooks: -# - id: setup-cfg-fmt -# - repo: https://github.com/PyCQA/flake8 -# rev: 4.0.1 -# hooks: -# - id: flake8 -# additional_dependencies: [flake8-typing-imports>=1.9.0] -# - repo: https://github.com/myint/autoflake -# rev: v1.4 -# hooks: -# - id: autoflake -# args: ["--in-place", "--remove-all-unused-imports"] -# - repo: https://github.com/PyCQA/isort -# rev: 5.10.1 -# hooks: -# - id: isort + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.257' + rev: 'v0.0.262' hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black args: [--line-length=79] -# - repo: https://github.com/asottile/pyupgrade -# rev: v2.29.1 -# hooks: -# - id: pyupgrade -# args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks rev: v0.3.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 5dec250c..d2a2adbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "isort", "black", "ruff", + "pre-commit", ] docs = [ "sphinx", @@ -72,4 +73,4 @@ test = [ "coverage", "tox", "twine", -] \ No newline at end of file +] From 8b0c7a8bc2e745c53ec612a23d8ddd6c4e48fa49 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:36:12 +0200 Subject: [PATCH 504/577] Latest pre-commit hooks --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da16a3b9..7053663e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,13 +2,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-docstring-first +# - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort + args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' From dcdfaca9c6573daa296a49f9adb42d19516cc33a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:39:57 +0200 Subject: [PATCH 505/577] Run full suite of pre-commit hooks --- README.md | 2 +- docs/res/code/plugin_convert.rst | 15 +++++++++++++ docs/res/code/utils.rst | 4 ++++ napari_cellseg3d/_tests/conftest.py | 1 + napari_cellseg3d/_tests/pytest.ini | 2 +- .../_tests/test_labels_correction.py | 3 ++- napari_cellseg3d/_tests/test_plugin_utils.py | 3 ++- .../code_models/instance_segmentation.py | 3 +-- .../dev_scripts/artefact_labeling.py | 13 ++++++----- .../dev_scripts/correct_labels.py | 22 ++++++++++--------- .../dev_scripts/evaluate_labels.py | 2 +- 11 files changed, 48 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ca8d0931..ece6c6f4 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). Please refer to the documentation for full acknowledgements. diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 25006d0f..03944510 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,3 +28,18 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ + +Functions +----------------------------------- + +save_folder +***************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder + +save_layer +**************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer + +show_result +**************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index d9fdcfa2..e90ee7e0 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,3 +62,7 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images + +format_Warning +************************************** +.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py index 4d4a4007..bbfeff10 100644 --- a/napari_cellseg3d/_tests/conftest.py +++ b/napari_cellseg3d/_tests/conftest.py @@ -1,4 +1,5 @@ import os + import pytest diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini index 814cca2e..45c3be1c 100644 --- a/napari_cellseg3d/_tests/pytest.ini +++ b/napari_cellseg3d/_tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -qt_api=pyqt5 \ No newline at end of file +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index 9d4e7801..c65d7402 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.dev_scripts import artefact_labeling as al from napari_cellseg3d.dev_scripts import correct_labels as cl diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 7ca0555f..5d5ada20 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,6 +1,7 @@ from pathlib import Path -from tifffile import imread + import numpy as np +from tifffile import imread from napari_cellseg3d.code_plugins.plugin_utilities import Utilities from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 6914a9e2..91b2c7c8 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -16,8 +16,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between -from napari_cellseg3d.utils import sphericity_axis from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import sphericity_axis # from napari_cellseg3d.utils import sphericity_volume_area @@ -515,7 +515,6 @@ def __init__(self, widget_parent=None): # self.counters[2].setValue(30) def run_method(self, image): - ################ # For debugging # import napari diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index bf724a46..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,14 +1,17 @@ -import numpy as np -from tifffile import imwrite, imread -import scipy.ndimage as ndimage import os + import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread +from tifffile import imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -from skimage.filters import threshold_otsu """ New code by Yves Paychere diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 50f2e47a..2f079d09 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -1,21 +1,23 @@ -import numpy as np -from tifffile import imread -from tifffile import imwrite -import scipy.ndimage as ndimage -import napari -from pathlib import Path -from functools import partial +import threading import time import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage from napari.qt.threading import thread_worker +from tifffile import imread +from tifffile import imwrite from tqdm import tqdm -import threading + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed -import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels """ New code by Yves Paychère diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index a972fa69..ee9919b6 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -1,7 +1,7 @@ +import napari import numpy as np import pandas as pd from tqdm import tqdm -import napari from napari_cellseg3d.utils import LOGGER as log From 122a733987323be4d24d40cabf701588624ee425 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 24 Mar 2023 17:08:44 +0100 Subject: [PATCH 506/577] Model class refactor --- docs/res/guides/custom_model_template.rst | 27 +-- .../_tests/test_weight_download.py | 2 +- .../code_models/models/model_SwinUNetR.py | 38 +--- .../code_models/models/model_TRAILMAP.py | 21 +- .../code_models/models/model_TRAILMAP_MS.py | 16 +- .../code_models/models/model_test.py | 20 +- napari_cellseg3d/code_models/workers.py | 205 +++++++++--------- .../code_plugins/plugin_model_inference.py | 94 ++++---- .../code_plugins/plugin_model_training.py | 4 +- .../code_plugins/plugin_review.py | 2 +- napari_cellseg3d/config.py | 16 +- napari_cellseg3d/interface.py | 20 +- napari_cellseg3d/utils.py | 41 +--- notebooks/assess_instance.ipynb | 121 ++++++----- requirements.txt | 3 +- setup.cfg | 2 +- 16 files changed, 280 insertions(+), 352 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index a70df29b..ddfb269f 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -6,30 +6,9 @@ Advanced : Declaring a custom model .. warning:: **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. - Please `file an issue`_ if you would like to add a custom model and we will help you get it working. - -To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: - - class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name - use_default_training = True # not needed for now, will serve for WNet training if added to the plugin - weights_file = ( - "model_template.pth" # specify the file name of the weights file only - ) # download URL goes in pretrained_models.json - - @abstractmethod - def __init__( - self, input_image_size, in_channels=1, out_channels=1, **kwargs - ): - """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" - pass - - @abstractmethod - def forward(self, x): - """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" - pass - - .. note:: **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute -.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues +:: + + diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 72dc939d..f3ba6654 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import WEIGHTS_DIR +from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR from napari_cellseg3d.code_models.model_workers import WeightsDownloader diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 2d7b5ef6..acd065d8 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -2,39 +2,19 @@ from napari_cellseg3d.utils import LOGGER -logger = LOGGER - - class SwinUNETR_(SwinUNETR): use_default_training = True weights_file = "Swin64_best_metric.pth" - def __init__( - self, - in_channels=1, - out_channels=1, - input_img_size=128, - use_checkpoint=True, - **kwargs, - ): - try: - super().__init__( - input_img_size, - in_channels=in_channels, - out_channels=out_channels, - feature_size=48, - use_checkpoint=use_checkpoint, - **kwargs, - ) - except TypeError as e: - logger.warning(f"Caught TypeError: {e}") - super().__init__( - input_img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - ) + def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 0d9ebace..8c7f3b70 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,26 +2,6 @@ from torch import nn -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" - # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them - - -def get_net(): - return TRAILMAP(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - return model(val_inputs) - - class TRAILMAP(nn.Module): def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() @@ -107,6 +87,7 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): out = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) + return out class TRAILMAP_(TRAILMAP): diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 73f842b1..c97b1370 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -3,13 +3,21 @@ logger = LOGGER - class TRAILMAP_MS_(UNet3D): use_default_training = True weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" - return out + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + # def get_output(self, input): + # out = self(input) -def get_validation(model, val_inputs): - return model(val_inputs) + # return out + # + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index d34f29e9..1cb52f06 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -3,21 +3,21 @@ class TestModel(nn.Module): - def __init__(self): + use_default_training = True + weights_file = "test.pth" + + def __init__(self, **kwargs): super().__init__() self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - def get_net(self): - return self - - def get_output(self, _, input): - return input + # def get_output(self, _, input): + # return input - def get_validation(self, val_inputs): - return val_inputs + # def get_validation(self, val_inputs): + # return val_inputs # if __name__ == "__main__": @@ -25,8 +25,8 @@ def get_validation(self, val_inputs): # model = TestModel() # model.train() # model.zero_grad() -# from napari_cellseg3d.config import WEIGHTS_DIR +# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR # torch.save( # model.state_dict(), -# WEIGHTS_DIR + f"/{get_weights_file()}" +# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" # ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 4462db41..7fee7e71 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from math import ceil from pathlib import Path +import typing as t import numpy as np import torch @@ -43,7 +44,10 @@ # from napari.qt.threading import thread_worker # threads -from napari.qt.threading import GeneratorWorker, WorkerBaseSignals +from napari.qt.threading import GeneratorWorker + +# from napari.qt.threading import thread_worker +from napari.qt.threading import WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -61,10 +65,19 @@ logger = utils.LOGGER +""" +Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) +so instead, following the instructions in the guides below to have a worker with custom signals, I implemented +a custom worker function.""" + +# FutureReference(): +# https://python-forum.io/thread-31349.html +# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ +# https://napari-staging-site.github.io/guides/stable/threading.html + PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) -VERBOSE_SCHEDULER = True logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") @@ -178,9 +191,9 @@ def safe_extract( class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `on this post`_ + Separate from Worker instances as indicated `here`_ - .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect """ # TODO link ? log_signal = Signal(str) @@ -315,6 +328,21 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Raises an error in main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + + self.log_signal.emit("!" * 20) + self.log_signal.emit("Error occured") + # self.log_signal.emit(msg) + # self.log_signal.emit(str(exception)) + + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + yield exception + # self.quit() + def log_parameters(self): config = self.config @@ -495,32 +523,11 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): - return post_process_transforms( - self.config.model_info.get_model().get_output(model, inputs) - ) - if self.config.keep_on_cpu: dataset_device = "cpu" else: dataset_device = self.config.device - window_size = self.config.sliding_window_config.window_size - window_overlap = self.config.sliding_window_config.window_overlap - - # FIXME - # import sys - - # old_stdout = sys.stdout - # old_stderr = sys.stderr - - # sys.stdout = self.downloader.log_widget - # sys.stdout = self.downloader.log_widget - - dataset_device = ( - "cpu" if self.config.keep_on_cpu else self.config.device - ) - if self.config.sliding_window_config.is_enabled(): window_size = self.config.sliding_window_config.window_size window_size = [window_size, window_size, window_size] @@ -534,27 +541,23 @@ def model_output(inputs): logger.debug(f"inputs type : {inputs.dtype}") try: # outputs = model(inputs) - inputs = utils.remap_image(inputs) def model_output_wrapper(inputs): result = model(inputs) return post_process_transforms(result) - with torch.no_grad(): - outputs = sliding_window_inference( - inputs, - roi_size=window_size, - sw_batch_size=1, # TODO add param - predictor=model_output_wrapper, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - mode="gaussian", - sigma_scale=0.01, - progress=True, - ) + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) except Exception as e: - logger.exception(e) + logger.error(e, exc_info=True) logger.debug("failed to run sliding window inference") self.raise_error(e, "Error during sliding window inference") logger.debug(f"Inference output shape: {outputs.shape}") @@ -565,9 +568,11 @@ def model_output_wrapper(inputs): if post_process: out = np.array(out).astype(np.float32) out = np.squeeze(out) - return out + return out + else: + return out except Exception as e: - logger.exception(e) + logger.error(e, exc_info=True) self.raise_error(e, "Error during sliding window inference") # sys.stdout = old_stdout # sys.stderr = old_stderr @@ -766,14 +771,8 @@ def run_crf(self, image, labels, image_id=0): return None def stats_csv(self, instance_labels): - try: - if self.config.compute_stats: - if len(instance_labels.shape) == 4: - stats = [volume_stats(c) for c in instance_labels] - else: - stats = [volume_stats(instance_labels)] - else: - stats = None + if self.config.compute_stats: + stats = volume_stats(instance_labels) return stats except ValueError as e: self.log(f"Error occurred during stats computing : {e}") @@ -858,47 +857,38 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config - if Path(weights_config.path).suffix == ".pt": - self.log("Instantiating PyTorch jit model...") - model = torch.jit.load(weights_config.path) + + # try: + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[128, 128, 128], + ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") # try: - elif Path(weights_config.path).suffix == ".onnx": - self.log("Instantiating ONNX model...") - model = ONNXModelWrapper(weights_config.path) - else: # assume is .pth - self.log("Instantiating model...") - model = model_class( # FIXME test if works - input_img_size=[dims, dims, dims], - device=self.config.device, - num_classes=self.config.model_info.num_classes, + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, ) - # try: - model = model.to(self.config.device) - # except Exception as e: - # self.raise_error(e, "Issue loading model to device") - # logger.debug(f"model : {model}") - if model is None: - raise ValueError("Model is None") - # try: - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path - else: - self.downloader.download_weights( - model_name, - model_class.weights_file, - ) - weights = str( - PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) - ) - - model.load_state_dict( # note that this is redefined in WNet_ - torch.load( - weights, - map_location=self.config.device, - ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - self.log("Done") + model.load_state_dict( + torch.load( + weights, + map_location=self.config.device, + ) + ) + self.log("Done") # except Exception as e: # self.raise_error(e, "Issue loading weights") # except Exception as e: @@ -986,7 +976,7 @@ def inference(self): model.to("cpu") # self.quit() except Exception as e: - logger.exception(e) + logger.error(e, exc_info=True) self.raise_error(e, "Inference failed") self.quit() finally: @@ -1078,6 +1068,14 @@ def warn(self, warning): """Sends a warning to main thread""" self.warn_signal.emit(warning) + def raise_error(self, exception, msg): + """Sends an error to main thread""" + logger.error(msg, exc_info=True) + logger.error(exception, exc_info=True) + self.error_signal.emit(exception, msg) + self.errored.emit(exception) + self.quit() + def log_parameters(self): self.log("-" * 20) self.log("Parameters summary :\n") @@ -1209,7 +1207,10 @@ def train(self): do_sampling = self.config.sampling - size = self.config.sample_size if do_sampling else check + if do_sampling: + size = self.config.sample_size + else: + size = check model = model_class( # FIXME check if correct input_img_size=utils.get_padding_dim(size), use_checkpoint=True @@ -1417,7 +1418,7 @@ def train(self): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") - logger.exception(e) + logger.error(e, exc_info=True) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" @@ -1512,18 +1513,18 @@ def train(self): val_data["image"].to(device), val_data["label"].to(device), ) + self.log("Performing validation...") try: - with torch.no_grad(): - val_outputs = sliding_window_inference( - val_inputs, - roi_size=size, - sw_batch_size=self.config.batch_size, - predictor=model, - overlap=0.25, - sw_device=self.config.device, - device=self.config.device, - progress=False, - ) + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) except Exception as e: self.raise_error(e, "Error during validation") logger.debug( diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 75336f26..290a00f5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -168,9 +168,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex( - self._default_window_size - ) # set to 64 by default + self.window_size_choice.setCurrentIndex(3) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -831,23 +829,27 @@ def on_yield(self, result: InferenceResult): ): zoom = self.worker_config.post_process_config.zoom.zoom_values - if ( - self.config.show_results - and image_id <= self.config.show_results_count - ): - zoom = self.worker_config.post_process_config.zoom.zoom_values + viewer.dims.ndisplay = 3 + viewer.scale_bar.visible = True + + if self.config.show_original and result.original is not None: + viewer.add_image( + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, + ) out_colormap = "twilight" if self.worker_config.post_process_config.thresholding.enabled: out_colormap = "turbo" - if self.config.show_original and result.original is not None: viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, ) if result.crf_results is not None: logger.debug( @@ -866,52 +868,38 @@ def on_yield(self, result: InferenceResult): self.worker_config.post_process_config.instance.method.name ) - out_colormap = "twilight" - if self.worker_config.post_process_config.thresholding.enabled: - out_colormap = "turbo" - - viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, - ) - - if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + if result.instance_labels is not None: + labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - if result.stats is not None and isinstance( - result.stats, list - ): - log.debug(f"len stats : {len(result.stats)}") + stats = result.stats - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if self.worker_config.compute_stats and stats is not None: + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances in channel {i} : {stats.number_objects[0]}" - ) + self.log.print_and_log( + f"Number of instances : {stats.number_objects}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index b4ca9848..cc34a161 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1056,7 +1056,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.exception(e) + logger.error(e, exc_info=True) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1224,7 +1224,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.exception(e) + logger.error(e, exc_info=True) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 77149208..23bcb4c5 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -400,7 +400,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.exception(e) + logger.error(e, exc_info=True) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 89d88c52..c30ba6d9 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -9,12 +9,11 @@ from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet -from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.code_models.models import ( - model_TRAILMAP_MS as TRAILMAP_MS, -) -from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ +from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ +from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ +from napari_cellseg3d.code_models.models.model_VNet import VNet_ + from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -28,12 +27,10 @@ # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, - "WNet": WNet_, # "test" : DO NOT USE, reserved for testing } - -WEIGHTS_DIR = str( +PRETRAINED_WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") ) @@ -73,7 +70,6 @@ class ModelInfo: Args: name (str): name of the model model_input_size (Optional[List[int]]): input size of the model - num_classes (int): number of classes for the model """ name: str = next(iter(MODEL_LIST)) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 50f19269..321a9c67 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -317,6 +317,22 @@ def error(self, error, msg=None): finally: self.lock.release() + def error(self, error, msg=None): + """Show exception and message from another thread""" + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Excepetion caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() + ############## # UI elements @@ -1237,8 +1253,8 @@ def open_folder_dialog( default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - return QFileDialog.getExistingDirectory( - widget, "Open directory", default_path # + "/.." + filenames = QFileDialog.getExistingDirectory( + widget, "Open directory", default_path + "/.." ) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 31ea1a65..b64adcb9 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,12 +1,8 @@ import logging from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Union - -import napari import numpy as np - -# from dask import delayed +from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian from tifffile import imread, imwrite @@ -120,7 +116,7 @@ def __call__(cls, *args, **kwargs): # if filename == "tif": # return True # def read(self, data, **kwargs): -# return imread(data) +# return tfl_imread(data) # # def get_data(self, data): # return data, {} @@ -208,39 +204,6 @@ def dice_coeff(y_true, y_pred): ) -def correct_rotation(image): - """Rotates the exes 0 and 2 in [DHW] section of image array""" - extra_dims = len(image.shape) - 3 - return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) - - -def normalize_max(image): - """Normalizes an image using the max and min value""" - shape = image.shape - image = image.flatten() - image = (image - image.min()) / (image.max() - image.min()) - image = image.reshape(shape) - return image - - -def remap_image( - image: Union["np.ndarray", "torch.Tensor"], - new_max=100, - new_min=0, - prev_max=None, - prev_min=None, -): - """Normalizes a numpy array or Tensor using the max and min value""" - shape = image.shape - image = image.flatten() - im_max = prev_max if prev_max is not None else image.max() - im_min = prev_min if prev_min is not None else image.min() - image = (image - im_min) / (im_max - im_min) - image = image * (new_max - new_min) + new_min - image = image.reshape(shape) - return image - - def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..59ae05c1 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,10 +44,20 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -57,14 +67,15 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"pred.tif\")\n", + "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = resize(prediction, zoom)\n", + "prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -85,7 +96,7 @@ { "data": { "text/plain": [ - "0.5817600487210719" + "0.7538125057831502" ] }, "execution_count": 4, @@ -96,9 +107,15 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", + "semantic_gt = to_semantic(gt_labels_resized.copy())\n", + "semantic_pred = to_semantic(prediction_resized.copy())\n", + "\n", + "viewer.add_image(semantic_gt, colormap='bop blue')\n", + "viewer.add_image(semantic_pred, colormap='red')\n", + "\n", "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", + " semantic_gt,\n", + " prediction_resized\n", ")" ] }, @@ -171,7 +188,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -198,24 +215,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-24 14:23:13,590 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", + "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", + "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" ] }, { @@ -228,15 +245,15 @@ { "data": { "text/plain": [ - "(65,\n", - " 46,\n", - " 13,\n", - " 12,\n", - " 0.9042297461803984,\n", - " 0.8512759824829847,\n", - " 0.9136359067720888,\n", - " 0.8728146835389444,\n", - " 1.0)" + "(63,\n", + " 45,\n", + " 16,\n", + " 16,\n", + " 0.819027731148306,\n", + " 0.8401649108992161,\n", + " 0.83609908334452,\n", + " 0.8066092803671974,\n", + " 0.98)" ] }, "execution_count": 9, @@ -262,24 +279,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" + "2023-03-24 14:23:13,732 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", + "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", + "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -292,15 +309,15 @@ { "data": { "text/plain": [ - "(68,\n", - " 43,\n", + "(77,\n", + " 34,\n", " 13,\n", - " 10,\n", - " 0.8856947654346812,\n", - " 0.8747475859219296,\n", - " 0.9187750563205743,\n", - " 0.862012598981557,\n", - " 1.0)" + " 9,\n", + " 0.728461197681457,\n", + " 0.8885669859686413,\n", + " 0.8950588507577087,\n", + " 0.7472814623489069,\n", + " 0.878614359974009)" ] }, "execution_count": 10, @@ -338,7 +355,7 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", @@ -414,24 +431,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-24 14:23:14,241 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", + "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", + "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", + "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" ] }, { @@ -444,15 +461,15 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", - " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + "(102,\n", + " 8,\n", + " 14,\n", + " 16,\n", + " 0.708505702558253,\n", + " 0.8832633585884945,\n", + " 0.9759871495093808,\n", + " 0.6670483272595948,\n", + " 0.8653680990771155)" ] }, "execution_count": 15, diff --git a/requirements.txt b/requirements.txt index a7dd1570..3ca0e56d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,13 +16,12 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 ruff tifffile>=2022.2.9 torch>=1.11 -monai[nibabel,einops]>=1.0.1 +monai[nibabel,einops,tifffile]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 2420dd1c..f3294b60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,einops]>=1.0.1 + monai[nibabel,einops,tifffile]>=1.0.1 itk tqdm nibabel From 926125d881dd551c35cada43d9df1e2a0841ccf2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 29 Mar 2023 09:55:58 +0200 Subject: [PATCH 507/577] Added LR scheduler in training - Added ReduceLROnPlateau with params in training - Updated training guide - Minor UI attribute refactor - black --- napari_cellseg3d/_tests/fixtures.py | 2 +- .../_tests/test_plugin_inference.py | 42 +++++-------------- .../code_models/instance_segmentation.py | 2 +- napari_cellseg3d/code_models/workers.py | 14 ++----- .../code_plugins/plugin_model_training.py | 2 + napari_cellseg3d/interface.py | 10 ++--- 6 files changed, 23 insertions(+), 49 deletions(-) diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index da34ae9b..ab8d0cb2 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -15,7 +15,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - logger.warning(warning) + warnings.warn(warning) def error(self, e): raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index fbeb9943..66c50fba 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -28,34 +28,14 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentText("WNet") - widget._restrict_window_size_for_model() - assert widget.window_infer_box.isChecked() - assert widget.window_size_choice.currentText() == "64" - - test_model_name = "test" - MODEL_LIST[test_model_name] = TestModel - widget.model_choice.addItem(test_model_name) - widget.model_choice.setCurrentText(test_model_name) - - widget.worker_config = widget._set_worker_config() - assert widget.worker_config is not None - assert widget.model_info is not None - worker = widget._create_worker_from_config(widget.worker_config) - - assert worker.config is not None - assert worker.config.model_info is not None - worker.config.layer = viewer.layers[0].data - worker.config.post_process_config.instance.enabled = True - worker.config.post_process_config.instance.method = ( - INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() - ) - - assert worker.config.layer is not None - worker.log_parameters() - - res = next(worker.inference()) - assert isinstance(res, InferenceResult) - assert res.result.shape == (8, 8, 8) - - widget.on_yield(res) + MODEL_LIST["test"] = TestModel + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + # widget.start() # takes too long on Github Actions + # assert widget.worker is not None + + # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: + # blocker.connect(widget.worker.errored) + + #### assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 91b2c7c8..c97d7ea7 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -74,7 +74,7 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label="", parent=None), + ui.DoubleIncrementCounter(text_label="", parent=None), ) self.counters.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 7fee7e71..c1458a6d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -78,6 +78,7 @@ PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) +VERBOSE_SCHEDULER = True logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") @@ -1376,23 +1377,14 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) - - factor = self.config.scheduler_factor - if factor >= 1.0: - self.log(f"Warning : scheduler factor is {factor} >= 1.0") - self.log("Setting it to 0.5") - factor = 0.5 - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode="min", - factor=factor, + factor=self.config.scheduler_factor, patience=self.config.scheduler_patience, verbose=VERBOSE_SCHEDULER, ) - dice_metric = DiceMetric( - include_background=False, reduction="mean" - ) + dice_metric = DiceMetric(include_background=True, reduction="mean") best_metric = -1 best_metric_epoch = -1 diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index cc34a161..eed04a57 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -849,6 +849,8 @@ def start(self): max_epochs=self.epoch_choice.value(), loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), + scheduler_patience=self.scheduler_patience_choice.value(), + scheduler_factor=self.scheduler_factor_choice.value(), validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 321a9c67..d25f02d3 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -544,10 +544,10 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.text_label.setVisible(visible) + self.label.setVisible(visible) def _build_container(self): - if self.text_label is not None: + if self.label is not None: add_widgets( self.container.layout, [ @@ -1093,7 +1093,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) @@ -1101,8 +1101,8 @@ def __init__( self.layout = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) self.valueChanged.connect(self._update_step) def _update_step(self): From e8a3e955529e972a6fd64cc377660e051420043c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 31 Mar 2023 15:45:00 +0200 Subject: [PATCH 508/577] Update assess_instance.ipynb --- notebooks/assess_instance.ipynb | 162 ++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 61 deletions(-) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index 59ae05c1..0dec4543 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -44,20 +44,10 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", - "In the future `np.bool` will be defined as the corresponding NumPy scalar.\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -67,15 +57,16 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", - "prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", "prediction = imread(prediction_path)\n", "gt_labels = imread(gt_labels_path)\n", "\n", "zoom = (1 / 5, 1, 1)\n", - "# prediction_resized = resize(prediction, zoom)\n", - "prediction_resized = prediction # for trailmap\n", + "prediction_resized = resize(prediction, zoom)\n", + "# prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -96,7 +87,7 @@ { "data": { "text/plain": [ - "0.7538125057831502" + "0.8592223181276479" ] }, "execution_count": 4, @@ -188,7 +179,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -215,24 +206,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,590 - Mapping labels...\n" + "2023-03-31 15:37:19,775 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 103/103 [00:00<00:00, 2689.96it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,631 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,634 - Percent of non-fused neurons found: 50.40%\n", - "2023-03-24 14:23:13,635 - Percent of fused neurons found: 36.00%\n", - "2023-03-24 14:23:13,635 - Overall percent of neurons found: 86.40%\n" + "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", + "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -245,15 +236,15 @@ { "data": { "text/plain": [ - "(63,\n", - " 45,\n", - " 16,\n", - " 16,\n", - " 0.819027731148306,\n", - " 0.8401649108992161,\n", - " 0.83609908334452,\n", - " 0.8066092803671974,\n", - " 0.98)" + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" ] }, "execution_count": 9, @@ -279,24 +270,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,732 - Mapping labels...\n" + "2023-03-31 15:37:19,919 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 5221.10it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:13,761 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:13,774 - Percent of non-fused neurons found: 61.60%\n", - "2023-03-24 14:23:13,775 - Percent of fused neurons found: 27.20%\n", - "2023-03-24 14:23:13,776 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", + "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -309,15 +300,15 @@ { "data": { "text/plain": [ - "(77,\n", - " 34,\n", + "(68,\n", + " 43,\n", " 13,\n", - " 9,\n", - " 0.728461197681457,\n", - " 0.8885669859686413,\n", - " 0.8950588507577087,\n", - " 0.7472814623489069,\n", - " 0.878614359974009)" + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" ] }, "execution_count": 10, @@ -343,6 +334,40 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:37:21,076 - build program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", + "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", + "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", + "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", + "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", + "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", + "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", + "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", + "2023-03-31 15:37:32,685 - build program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", + "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" + ] + }, { "data": { "text/plain": [ @@ -431,24 +456,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,241 - Mapping labels...\n" + "2023-03-31 15:37:36,854 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 2376.22it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-24 14:23:14,301 - Calculating the number of neurons not found...\n", - "2023-03-24 14:23:14,303 - Percent of non-fused neurons found: 81.60%\n", - "2023-03-24 14:23:14,304 - Percent of fused neurons found: 6.40%\n", - "2023-03-24 14:23:14,305 - Overall percent of neurons found: 88.00%\n" + "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", + "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", + "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -461,15 +486,15 @@ { "data": { "text/plain": [ - "(102,\n", + "(109,\n", + " 2,\n", + " 13,\n", " 8,\n", - " 14,\n", - " 16,\n", - " 0.708505702558253,\n", - " 0.8832633585884945,\n", - " 0.9759871495093808,\n", - " 0.6670483272595948,\n", - " 0.8653680990771155)" + " 0.8285521200005869,\n", + " 0.8809251900364068,\n", + " 0.9838709677419355,\n", + " 0.782258064516129,\n", + " 1.0)" ] }, "execution_count": 15, @@ -490,10 +515,25 @@ "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" + ] + } + ], "source": [ "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -512,7 +552,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, From 954cc5429f8d82e5a6fd332d8e2897429dc70efe Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 11:09:30 +0200 Subject: [PATCH 509/577] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df43b4fa..df67a187 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb *.prof #include test data From 4fc31ba6b3c8e9c7ea4928f60d212c65d3a1d7c3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 19 Apr 2023 14:27:21 +0200 Subject: [PATCH 510/577] Started adding WNet --- .../code_models/models/model_SwinUNetR.py | 29 +++-- .../code_models/models/model_TRAILMAP_MS.py | 14 ++- .../code_models/models/model_WNet.py | 31 ++--- .../code_models/models/wnet/crf.py | 112 ++++++++++++++++++ .../code_models/models/wnet/model.py | 96 ++++++--------- .../code_models/models/wnet/soft_Ncuts.py | 12 +- napari_cellseg3d/code_models/workers.py | 4 +- napari_cellseg3d/config.py | 3 + 8 files changed, 203 insertions(+), 98 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index acd065d8..f582e367 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,7 @@ from monai.networks.nets import SwinUNETR +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER from napari_cellseg3d.utils import LOGGER @@ -7,14 +10,24 @@ class SwinUNETR_(SwinUNETR): weights_file = "Swin64_best_metric.pth" def __init__(self, input_img_size, use_checkpoint=True, **kwargs): - super().__init__( - input_img_size, - in_channels=1, - out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - **kwargs - ) + try: + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs, + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index c97b1370..66d61201 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -3,6 +3,8 @@ logger = LOGGER +logger = LOGGER + class TRAILMAP_MS_(UNet3D): use_default_training = True weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" @@ -10,9 +12,15 @@ class TRAILMAP_MS_(UNet3D): # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) def __init__(self, in_channels=1, out_channels=1, **kwargs): - super().__init__( - in_channels=in_channels, out_channels=out_channels, **kwargs - ) + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError as e: + logger.warn(f"Caught TypeError: {e}") + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) # def get_output(self, input): # out = self(input) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 62142e73..63a91b10 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,8 +1,7 @@ -# local -from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder +from napari_cellseg3d.code_models.models.wnet.model import WNet -class WNet_(WNet_encoder): +class WNet_(WNet): use_default_training = False weights_file = "wnet.pth" @@ -12,7 +11,7 @@ def __init__( out_channels=1, num_classes=2, device="cpu", - **kwargs, + **kwargs ): super().__init__( device=device, @@ -21,22 +20,8 @@ def __init__( num_classes=num_classes, ) - # def train(self: T, mode: bool = True) -> T: - # raise NotImplementedError("Training not implemented for WNet") - - # def forward(self, x): - # """Forward ENCODER pass of the W-Net model. - # Done this way to allow inference on the encoder only when called by sliding_window_inference. - # """ - # return self.forward_encoder(x) - # # enc = self.forward_encoder(x) - # # return self.forward_decoder(enc) - - def load_state_dict(self, state_dict, strict=False): - """Load the model state dict for inference, without the decoder weights.""" - encoder_checkpoint = state_dict.copy() - for k in state_dict: - if k.startswith("decoder"): - encoder_checkpoint.pop(k) - # print(encoder_checkpoint.keys()) - super().load_state_dict(encoder_checkpoint, strict=strict) + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + # dec = self.forward_decoder(enc) + return enc diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py new file mode 100644 index 00000000..ca11fba2 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -0,0 +1,112 @@ +""" +Implements the CRF post-processing step for the W-Net. +Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + +Also uses research from: +Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials +Philipp Krähenbühl and Vladlen Koltun +NIPS 2011 + +Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. +""" + +import numpy as np +import pydensecrf.densecrf as dcrf +from pydensecrf.utils import ( + unary_from_softmax, + create_pairwise_gaussian, + create_pairwise_bilateral, +) + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Philipp Krähenbühl", + "Vladlen Koltun", + "Liang-Chieh Chen", + "George Papandreou", + "Iasonas Kokkinos", + "Kevin Murphy", + "Alan L. Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. + """ + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 2900b89c..585ea0dd 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -16,19 +16,6 @@ ] -class WNet_encoder(nn.Module): - """WNet with encoder only.""" - - def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): - super().__init__() - self.device = device - self.encoder = UNet(device, in_channels, num_classes, encoder=True) - - def forward(self, x): - """Forward pass of the W-Net model.""" - return self.encoder(x) - - class WNet(nn.Module): """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. The model performs unsupervised segmentation of 3D images. @@ -49,43 +36,30 @@ def forward(self, x): def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" - return self.encoder(x) + enc = self.encoder(x) + return enc def forward_decoder(self, enc): """Forward pass of the decoder part of the W-Net model.""" - return self.decoder(enc) + dec = self.decoder(enc) + return dec class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" - def __init__( - self, device, in_channels, out_channels, encoder=True, dropout=0.65 - ): + def __init__(self, device, in_channels, out_channels, encoder=True): super(UNet, self).__init__() self.device = device - self.max_pool = nn.MaxPool3d(2) - self.in_b = InBlock(device, in_channels, 64, dropout=dropout) - self.conv1 = Block(device, 64, 128, dropout=dropout) - self.conv2 = Block(device, 128, 256, dropout=dropout) - self.conv3 = Block(device, 256, 512, dropout=dropout) - self.bot = Block(device, 512, 1024, dropout=dropout) - self.deconv1 = Block(device, 1024, 512, dropout=dropout) - self.conv_trans1 = nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device - ) - self.deconv2 = Block(device, 512, 256, dropout=dropout) - self.conv_trans2 = nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device - ) - self.deconv3 = Block(device, 256, 128, dropout=dropout) - self.conv_trans3 = nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device - ) - self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) - self.conv_trans_out = nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device - ) + self.in_b = InBlock(device, in_channels, 64) + self.conv1 = Block(device, 64, 128) + self.conv2 = Block(device, 128, 256) + self.conv3 = Block(device, 256, 512) + self.bot = Block(device, 512, 1024) + self.deconv1 = Block(device, 1024, 512) + self.deconv2 = Block(device, 512, 256) + self.deconv3 = Block(device, 256, 128) + self.out_b = OutBlock(device, 128, out_channels) self.sm = nn.Softmax(dim=1).to(device) self.encoder = encoder @@ -93,15 +67,17 @@ def __init__( def forward(self, x): """Forward pass of the U-Net model.""" in_b = self.in_b(x.to(self.device)) - c1 = self.conv1(self.max_pool(in_b)) - c2 = self.conv2(self.max_pool(c1)) - c3 = self.conv3(self.max_pool(c2)) - x = self.bot(self.max_pool(c3)) + c1 = self.conv1(nn.MaxPool3d(2)(in_b)) + c2 = self.conv2(nn.MaxPool3d(2)(c1)) + c3 = self.conv3(nn.MaxPool3d(2)(c2)) + x = self.bot(nn.MaxPool3d(2)(c3)) x = self.deconv1( torch.cat( [ c3, - self.conv_trans1(x), + nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + )(x), ], dim=1, ) @@ -110,7 +86,9 @@ def forward(self, x): torch.cat( [ c2, - self.conv_trans2(x), + nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + )(x), ], dim=1, ) @@ -119,7 +97,9 @@ def forward(self, x): torch.cat( [ c1, - self.conv_trans3(x), + nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + )(x), ], dim=1, ) @@ -128,7 +108,9 @@ def forward(self, x): torch.cat( [ in_b, - self.conv_trans_out(x), + nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + )(x), ], dim=1, ) @@ -141,17 +123,17 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, device, in_channels, out_channels): super(InBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -163,19 +145,19 @@ def forward(self, x): class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, device, in_channels, out_channels): super(Block, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), nn.Conv3d(in_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.Conv3d(out_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -187,17 +169,17 @@ def forward(self, x): class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, device, in_channels, out_channels): super(OutBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=dropout), + nn.Dropout(p=0.65), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, out_channels, 1, device=device), ).to(device) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 938292c2..6a625355 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -1,15 +1,15 @@ """ Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. -The implementation was adapted and approximated to reduce computational and memory cost. +The implementation was adapted and approximated to reduce computational and memory cost. This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. """ import math - -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F + +import numpy as np from scipy.stats import norm __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" @@ -56,7 +56,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): # self.distances, self.indexes = self.get_distances() """ - + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) @@ -206,7 +206,6 @@ def forward(self, labels, inputs): return torch.add(torch.neg(loss), K) """ - return None def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -349,4 +348,5 @@ def get_weights(self, inputs): 1, 1, self.W_X.shape[0], self.W_X.shape[1] ) # (1, 1, H*W*D, H*W*D) - return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) + W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) + return W diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c1458a6d..e355c6c8 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -862,7 +862,9 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=[128, 128, 128], + input_img_size=dims, + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) # try: model = model.to(self.config.device) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index c30ba6d9..6c749096 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -13,6 +13,7 @@ from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ +from napari_cellseg3d.code_models.models.model_WNet import WNet_ from napari_cellseg3d.utils import LOGGER @@ -27,6 +28,7 @@ # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, + "WNet": WNet_, # "test" : DO NOT USE, reserved for testing } @@ -70,6 +72,7 @@ class ModelInfo: Args: name (str): name of the model model_input_size (Optional[List[int]]): input size of the model + num_classes (int): number of classes for the model """ name: str = next(iter(MODEL_LIST)) From 8f0ca84bc65f798ca9f2199f3defeddf6360a495 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 20 Apr 2023 11:12:59 +0200 Subject: [PATCH 511/577] Specify no grad in inference --- napari_cellseg3d/code_models/workers.py | 42 +++++++++++++------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index e355c6c8..680ce0ac 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -547,16 +547,17 @@ def model_output_wrapper(inputs): result = model(inputs) return post_process_transforms(result) - outputs = sliding_window_inference( - inputs, - roi_size=window_size, - sw_batch_size=1, # TODO add param - predictor=model_output_wrapper, - sw_device=self.config.device, - device=dataset_device, - overlap=window_overlap, - progress=True, - ) + with torch.no_grad(): + outputs = sliding_window_inference( + inputs, + roi_size=window_size, + sw_batch_size=1, # TODO add param + predictor=model_output_wrapper, + sw_device=self.config.device, + device=dataset_device, + overlap=window_overlap, + progress=True, + ) except Exception as e: logger.error(e, exc_info=True) logger.debug("failed to run sliding window inference") @@ -1509,16 +1510,17 @@ def train(self): ) self.log("Performing validation...") try: - val_outputs = sliding_window_inference( - val_inputs, - roi_size=size, - sw_batch_size=self.config.batch_size, - predictor=model, - overlap=0.25, - sw_device=self.config.device, - device=self.config.device, - progress=True, - ) + with torch.no_grad(): + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=True, + ) except Exception as e: self.raise_error(e, "Error during validation") logger.debug( From 8a36ebf875563f5fd09e22f92bdfbf850a012eca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 22 Apr 2023 14:12:32 +0200 Subject: [PATCH 512/577] First functional WNet inference, no CRF --- .../code_models/models/model_WNet.py | 3 +- napari_cellseg3d/code_models/workers.py | 52 ++++++++++------- .../code_plugins/plugin_model_inference.py | 57 +++++++++++-------- 3 files changed, 66 insertions(+), 46 deletions(-) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 63a91b10..dffa3b44 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,8 @@ def __init__( ) def forward(self, x): - """Forward pass of the W-Net model.""" + """Forward ENCODER pass of the W-Net model. + Done this way to allow inference on the encoder only when called by sliding_window_inference.""" enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 680ce0ac..11da820f 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -256,7 +256,6 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None - crf_results: np.array = None stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -599,15 +598,10 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - - if semantic_labels is not None: - semantic_labels = utils.correct_rotation(semantic_labels) - if crf_results is not None: - crf_results = utils.correct_rotation(crf_results) - if instance_labels is not None: - instance_labels = utils.correct_rotation( - instance_labels - ) # TODO(cyril) check if correct + total_dims = len(semantic_labels.shape) - 3 + semantic_labels = np.swapaxes( + semantic_labels, 0 + total_dims, 2 + total_dims + ) return InferenceResult( image_id=i + 1, @@ -655,7 +649,7 @@ def save_image( filetype = self.config.filetype else: original_filename = "_" - filetype = ".tif" + filetype = "" time = utils.get_date_time() @@ -666,7 +660,7 @@ def save_image( + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}" + + f"_{time}_" + filetype ) try: @@ -692,18 +686,31 @@ def aniso_transform(self, image): else: return image - def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + def instance_seg( + self, to_instance, image_id=0, original_filename="layer", channel=None + ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method instance_labels = method.run_method(image=to_instance) + if channel is not None: + channel_id = f"_{channel}" + else: + channel_id = "" + + if self.config.filetype == "": + filetype = "" + else: + filetype = "_" + self.config.filetype + instance_filepath = ( self.config.results_path + "/" + f"Instance_seg_labels_{image_id}_" + original_filename + + channel_id + "_" + self.config.model_info.name + f"_{utils.get_date_time()}" @@ -795,18 +802,21 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, stats = self.get_instance_result( - semantic_labels=out, from_layer=True - ) + instance_labels_results = [] + stats_results = [] - crf_results = self.run_crf(image, out) if self.config.use_crf else None + for channel in out: + instance_labels, stats = self.get_instance_result( + channel, from_layer=True + ) + instance_labels_results.append(instance_labels) + stats_results.append(stats) return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels, - crf_results=crf_results, + instance_labels=instance_labels_results, from_layer=True, - stats=stats, + stats=stats_results, ) # @thread_worker(connect={"errored": self.raise_error}) @@ -863,7 +873,7 @@ def inference(self): # try: self.log("Instantiating model...") model = model_class( # FIXME test if works - input_img_size=dims, + input_img_size=[dims, dims, dims], device=self.config.device, num_classes=self.config.model_info.num_classes, ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 290a00f5..934f51e7 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -869,37 +869,46 @@ def on_yield(self, result: InferenceResult): ) if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + for i, labels in enumerate(result.instance_labels): + # labels = result.instance_labels + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(labels, name=name) - stats = result.stats + from napari_cellseg3d.utils import LOGGER as log - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + log.debug(f"len stats : {len(result.stats)}") - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + for i, stats in enumerate(result.stats): + # stats = result.stats - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) + + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) + + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) except Exception as e: self.on_error(e) From fbfc513a3155156d7996825c8f79702578256396 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 10:48:12 +0200 Subject: [PATCH 513/577] Create test_models.py --- napari_cellseg3d/_tests/test_models.py | 88 +------------------------- 1 file changed, 2 insertions(+), 86 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index ec7462db..e2ba32e0 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,30 +1,8 @@ -import numpy as np -import torch -from numpy.random import PCG64, Generator - -from napari_cellseg3d.code_models.crf import ( - CRFWorker, - correct_shape_for_crf, - crf_batch, - crf_with_config, -) -from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss -from napari_cellseg3d.config import MODEL_LIST, CRFConfig - -rand_gen = Generator(PCG64(12345)) - - -def test_correct_shape_for_crf(): - test = rand_gen.random(size=(1, 1, 8, 8, 8)) - assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) - test = rand_gen.random(size=(8, 8, 8)) - assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) +from napari_cellseg3d.config import MODEL_LIST def test_model_list(): - for model_name in MODEL_LIST: - # if model_name=="test": - # continue + for model_name in MODEL_LIST.keys(): dims = 128 test = MODEL_LIST[model_name]( input_img_size=[dims, dims, dims], @@ -33,65 +11,3 @@ def test_model_list(): dropout_prob=0.3, ) assert isinstance(test, MODEL_LIST[model_name]) - - -def test_soft_ncuts_loss(): - dims = 8 - labels = torch.rand([1, 1, dims, dims, dims]) - - loss = SoftNCutsLoss( - data_shape=[dims, dims, dims], - device="cpu", - o_i=4, - o_x=4, - radius=2, - ) - - res = loss.forward(labels, labels) - assert isinstance(res, torch.Tensor) - assert 0 <= res <= 1 - - -def test_crf_batch(): - dims = 8 - mock_image = rand_gen.random(size=(1, dims, dims, dims)) - mock_label = rand_gen.random(size=(2, dims, dims, dims)) - config = CRFConfig() - - result = crf_batch( - np.array([mock_image, mock_image, mock_image]), - np.array([mock_label, mock_label, mock_label]), - sa=config.sa, - sb=config.sb, - sg=config.sg, - w1=config.w1, - w2=config.w2, - ) - - assert result.shape == (3, 2, dims, dims, dims) - - -def test_crf_config(): - dims = 8 - mock_image = rand_gen.random(size=(1, dims, dims, dims)) - mock_label = rand_gen.random(size=(2, dims, dims, dims)) - config = CRFConfig() - - result = crf_with_config(mock_image, mock_label, config) - assert result.shape == mock_label.shape - - -def test_crf_worker(qtbot): - dims = 8 - mock_image = rand_gen.random(size=(1, dims, dims, dims)) - mock_label = rand_gen.random(size=(2, dims, dims, dims)) - assert len(mock_label.shape) == 4 - crf = CRFWorker([mock_image], [mock_label]) - - def on_yield(result): - assert len(result.shape) == 4 - assert len(mock_label.shape) == 4 - assert result.shape[-3:] == mock_label.shape[-3:] - - result = next(crf._run_crf_job()) - on_yield(result) From b280f4b59707ade893523da3d28f2c650e88017e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 14:42:56 +0200 Subject: [PATCH 514/577] Run full suite of pre-commit hooks --- docs/res/guides/custom_model_template.rst | 2 -- napari_cellseg3d/code_models/instance_segmentation.py | 2 ++ napari_cellseg3d/code_models/models/model_SwinUNetR.py | 1 + napari_cellseg3d/code_models/models/model_WNet.py | 3 ++- napari_cellseg3d/code_models/models/wnet/crf.py | 8 +++----- napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py | 8 ++++---- napari_cellseg3d/code_models/workers.py | 4 ---- napari_cellseg3d/config.py | 1 - napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 - 9 files changed, 12 insertions(+), 18 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index ddfb269f..35d21137 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -10,5 +10,3 @@ Advanced : Declaring a custom model **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute :: - - diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index c97d7ea7..57f0403b 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -13,6 +13,8 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index f582e367..c687eac2 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,5 @@ from monai.networks.nets import SwinUNETR + from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index dffa3b44..750b8bdb 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -22,7 +22,8 @@ def __init__( def forward(self, x): """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference.""" + Done this way to allow inference on the encoder only when called by sliding_window_inference. + """ enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index ca11fba2..2ac0875d 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,11 +12,9 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - unary_from_softmax, - create_pairwise_gaussian, - create_pairwise_bilateral, -) +from pydensecrf.utils import create_pairwise_bilateral +from pydensecrf.utils import create_pairwise_gaussian +from pydensecrf.utils import unary_from_softmax __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 6a625355..4e84579f 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -1,15 +1,15 @@ """ Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. -The implementation was adapted and approximated to reduce computational and memory cost. +The implementation was adapted and approximated to reduce computational and memory cost. This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. """ import math + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -import numpy as np from scipy.stats import norm __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" @@ -56,7 +56,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): # self.distances, self.indexes = self.get_distances() """ - + # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 11da820f..b452449e 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1,10 +1,8 @@ import platform -import time import typing as t from dataclasses import dataclass from math import ceil from pathlib import Path -import typing as t import numpy as np import torch @@ -45,8 +43,6 @@ # from napari.qt.threading import thread_worker # threads from napari.qt.threading import GeneratorWorker - -# from napari.qt.threading import thread_worker from napari.qt.threading import WorkerBaseSignals # Qt diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 6c749096..7250fe78 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -14,7 +14,6 @@ from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ from napari_cellseg3d.code_models.models.model_VNet import VNet_ from napari_cellseg3d.code_models.models.model_WNet import WNet_ - from napari_cellseg3d.utils import LOGGER logger = LOGGER diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..48249a94 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,5 +1,4 @@ import os - import napari import numpy as np import scipy.ndimage as ndimage From c2b5168ce6b9f0a67eb9f139b950d267660158e7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 15:27:18 +0200 Subject: [PATCH 515/577] Patch for tests action + style --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/code_models/instance_segmentation.py | 6 ++++-- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- napari_cellseg3d/dev_scripts/artefact_labeling.py | 1 + napari_cellseg3d/utils.py | 1 + 5 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5401bfd0..f230c9ec 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -16,6 +16,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet workflow_dispatch: jobs: diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 57f0403b..78620c2e 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -13,14 +13,16 @@ # from skimage.measure import mesh_surface_area # from skimage.measure import marching_cubes from tifffile import imread -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import sphericity_axis +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 750b8bdb..4a9ff70d 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -11,7 +11,7 @@ def __init__( out_channels=1, num_classes=2, device="cpu", - **kwargs + **kwargs, ): super().__init__( device=device, diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 48249a94..c7e6c6ee 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ import os + import napari import numpy as np import scipy.ndimage as ndimage diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index b64adcb9..e4330bad 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from pathlib import Path + import numpy as np from monai.transforms import Zoom from skimage import io From 03766211304dd30d7c72225f6ed6b57c22f774f9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 23 Apr 2023 16:03:29 +0200 Subject: [PATCH 516/577] Add softNCuts basic test --- napari_cellseg3d/_tests/test_models.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index e2ba32e0..9280b230 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,3 +1,6 @@ +import torch + +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST @@ -11,3 +14,20 @@ def test_model_list(): dropout_prob=0.3, ) assert isinstance(test, MODEL_LIST[model_name]) + + +def test_soft_ncuts_loss(): + dims = 8 + labels = torch.rand([1, 1, dims, dims, dims]) + + loss = SoftNCutsLoss( + data_shape=[dims, dims, dims], + device="cpu", + o_i=4, + o_x=4, + radius=2, + ) + + res = loss.forward(labels, labels) + assert isinstance(res, torch.Tensor) + # assert res > 0 From 181a9346fccb8708af15951581f17659189ec66e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 09:41:15 +0200 Subject: [PATCH 517/577] Added crf Co-Authored-By: Nevexios <72894299+nevexios@users.noreply.github.com> --- napari_cellseg3d/code_models/crf.py | 123 ++-------------------------- pyproject.toml | 3 + 2 files changed, 10 insertions(+), 116 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index b362246a..13f489c7 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -9,32 +9,25 @@ Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. """ + from warnings import warn +import numpy as np + try: import pydensecrf.densecrf as dcrf - from pydensecrf.utils import ( - create_pairwise_bilateral, - create_pairwise_gaussian, - unary_from_softmax, - ) + from pydensecrf.utils import create_pairwise_bilateral + from pydensecrf.utils import create_pairwise_gaussian + from pydensecrf.utils import unary_from_softmax CRF_INSTALLED = True except ImportError: warn( "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]", - stacklevel=1, + "Please install by running pip install cellseg3d[crf]" ) CRF_INSTALLED = False - -import numpy as np -from napari.qt.threading import GeneratorWorker - -from napari_cellseg3d.config import CRFConfig -from napari_cellseg3d.utils import LOGGER as logger - __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ "Yves Paychère", @@ -53,21 +46,6 @@ ] -def correct_shape_for_crf(image, desired_dims=4): - logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") - logger.debug(f"Image shape: {image.shape}") - if len(image.shape) > desired_dims: - # if image.shape[0] > 1: - # raise ValueError( - # f"Image shape {image.shape} might have several channels" - # ) - image = np.squeeze(image, axis=0) - elif len(image.shape) < desired_dims: - image = np.expand_dims(image, axis=0) - logger.debug(f"Corrected image shape: {image.shape}") - return image - - def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): """CRF post-processing step for the W-Net, applied to a batch of images. @@ -81,8 +59,6 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): Returns: np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. """ - if not CRF_INSTALLED: - return None return np.stack( [ @@ -104,16 +80,10 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - w1 (float): weight of the appearance/bilateral kernel. - w2 (float): weight of the smoothness/gaussian kernel. Returns: np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. """ - - if not CRF_INSTALLED: - return None - d = dcrf.DenseCRF( image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] ) @@ -150,82 +120,3 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): return np.array(Q).reshape( (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) ) - - -def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): - if config is None: - config = CRFConfig() - if image.shape[-3:] != prob.shape[-3:]: - raise ValueError( - f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" - f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" - ) - - image = correct_shape_for_crf(image) - prob = correct_shape_for_crf(prob) - - if log is not None: - log("Running CRF post-processing step") - log(f"Image shape : {image.shape}") - log(f"Labels shape : {prob.shape}") - - return crf( - image, - prob, - config.sa, - config.sb, - config.sg, - config.w1, - config.w2, - config.n_iters, - ) - - -class CRFWorker(GeneratorWorker): - """Worker for the CRF post-processing step for the W-Net.""" - - def __init__( - self, - images_list: list, - labels_list: list, - config: CRFConfig = None, - log=None, - ): - super().__init__(self._run_crf_job) - - self.images = images_list - self.labels = labels_list - if config is None: - self.config = CRFConfig() - else: - self.config = config - self.log = log - - def _run_crf_job(self): - """Runs the CRF post-processing step for the W-Net.""" - if not CRF_INSTALLED: - raise ImportError("pydensecrf is not installed.") - - if len(self.images) != len(self.labels): - raise ValueError("Number of images and labels must be the same.") - - for i in range(len(self.images)): - if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: - raise ValueError("Image and labels must have the same shape.") - - im = correct_shape_for_crf(self.images[i]) - prob = correct_shape_for_crf(self.labels[i]) - - logger.debug(f"image shape : {im.shape}") - logger.debug(f"labels shape : {prob.shape}") - - yield crf( - im, - prob, - self.config.sa, - self.config.sb, - self.config.sg, - self.config.w1, - self.config.w2, - n_iter=self.config.n_iters, - ) diff --git a/pyproject.toml b/pyproject.toml index d2a2adbb..d9a46ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ profile = "black" line_length = 79 [project.optional-dependencies] +crf = [ +"git+https://github.com/lucasb-eyer/pydensecrf.git", +] dev = [ "isort", "black", From 164eac6a79751222d0fab7a6b10e5da518e1aec5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 10:08:46 +0200 Subject: [PATCH 518/577] More pre-commit checks --- .pre-commit-config.yaml | 10 +-- napari_cellseg3d/_tests/fixtures.py | 4 +- napari_cellseg3d/_tests/test_plugin_utils.py | 6 +- napari_cellseg3d/_tests/test_utils.py | 7 +- .../_tests/test_weight_download.py | 6 +- napari_cellseg3d/code_models/crf.py | 11 +-- .../code_models/instance_segmentation.py | 6 +- .../code_models/models/unet/model.py | 1 - .../code_models/models/wnet/crf.py | 8 ++- napari_cellseg3d/code_models/workers.py | 12 ++-- .../code_plugins/plugin_convert.py | 14 ++-- napari_cellseg3d/code_plugins/plugin_crop.py | 1 - .../code_plugins/plugin_model_inference.py | 16 +++-- .../code_plugins/plugin_model_training.py | 6 +- .../code_plugins/plugin_review.py | 10 +-- .../code_plugins/plugin_utilities.py | 16 ++--- .../dev_scripts/artefact_labeling.py | 3 +- .../dev_scripts/correct_labels.py | 3 +- napari_cellseg3d/dev_scripts/thread_test.py | 3 - napari_cellseg3d/interface.py | 72 +++++++------------ napari_cellseg3d/utils.py | 4 +- pyproject.toml | 7 +- 22 files changed, 108 insertions(+), 118 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..61ecaae5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,11 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index ab8d0cb2..b3044799 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,5 +1,3 @@ -import warnings - from qtpy.QtWidgets import QTextEdit from napari_cellseg3d.utils import LOGGER as logger @@ -15,7 +13,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - warnings.warn(warning) + logger.warning(warning) def error(self, e): raise (e) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 5d5ada20..0f183fa4 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -3,8 +3,10 @@ import numpy as np from tifffile import imread -from napari_cellseg3d.code_plugins.plugin_utilities import Utilities -from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS +from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + Utilities, +) rand_gen = Generator(PCG64(12345)) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 12720688..05b84b08 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,5 +1,5 @@ +import os from functools import partial -from pathlib import Path import numpy as np import torch @@ -35,7 +35,8 @@ def test_fill_list_in_between(): fill = partial(utils.fill_list_in_between, n=2, fill_value="") - assert fill(test_list) == res + assert fill(list) == res + def test_align_array_sizes(): im = np.zeros((128, 512, 256)) @@ -88,7 +89,7 @@ def test_get_padding_dim(): # "The padding value is currently 2048." # ) # - pad = utils.get_padding_dim(size) + # pad = utils.get_padding_dim(size) # # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index f3ba6654..042c9524 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,5 +1,7 @@ -from napari_cellseg3d.code_models.model_workers import PRETRAINED_WEIGHTS_DIR -from napari_cellseg3d.code_models.model_workers import WeightsDownloader +from napari_cellseg3d.code_models.model_workers import ( + PRETRAINED_WEIGHTS_DIR, + WeightsDownloader, +) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 13f489c7..fc1e0b90 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -16,15 +16,18 @@ try: import pydensecrf.densecrf as dcrf - from pydensecrf.utils import create_pairwise_bilateral - from pydensecrf.utils import create_pairwise_gaussian - from pydensecrf.utils import unary_from_softmax + from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, + ) CRF_INSTALLED = True except ImportError: warn( "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]" + "Please install by running pip install cellseg3d[crf]", + stacklevel=1, ) CRF_INSTALLED = False diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 78620c2e..b4177ec0 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -5,8 +5,7 @@ import numpy as np import pyclesperanto_prototype as cle from qtpy.QtWidgets import QWidget -from skimage.measure import label -from skimage.measure import regionprops +from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed @@ -15,9 +14,8 @@ from tifffile import imread from napari_cellseg3d import interface as ui -from napari_cellseg3d.utils import fill_list_in_between from napari_cellseg3d.utils import LOGGER as logger -from napari_cellseg3d.utils import sphericity_axis +from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis # from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index ee566be7..9591d054 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -5,7 +5,6 @@ create_decoders, create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py index 2ac0875d..004db3a1 100644 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ b/napari_cellseg3d/code_models/models/wnet/crf.py @@ -12,9 +12,11 @@ import numpy as np import pydensecrf.densecrf as dcrf -from pydensecrf.utils import create_pairwise_bilateral -from pydensecrf.utils import create_pairwise_gaussian -from pydensecrf.utils import unary_from_softmax +from pydensecrf.utils import ( + create_pairwise_bilateral, + create_pairwise_gaussian, + unary_from_softmax, +) __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index b452449e..5f34f9c3 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -42,8 +42,7 @@ # from napari.qt.threading import thread_worker # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -53,11 +52,10 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils - -# local -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats +from napari_cellseg3d.code_models.model_instance_seg import ( + ImageStats, + volume_stats, +) logger = utils.LOGGER diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 2dc8f07c..ce4cecda 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -3,14 +3,16 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, + clear_small_objects, + threshold, + to_semantic, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder MAX_W = ui.UTILS_MAX_WIDTH @@ -79,7 +81,7 @@ def show_result(viewer, layer, image, name): logger.debug("Added resulting label layer") viewer.add_labels(image, name=name) else: - warnings.warn( + logger.warning( f"Results not shown, unsupported layer type {type(layer)}" ) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 5b09ad3f..957936ce 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,4 +1,3 @@ -from math import floor from pathlib import Path import napari diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 934f51e7..218dbeaa 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,15 +10,17 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.instance_segmentation import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_instance_seg import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod -from napari_cellseg3d.code_models.model_instance_seg import InstanceWidgets -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) + +logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -587,7 +589,7 @@ def start(self): if not self._check_results_path(save_path): msg = f"ERROR: please set valid results path. Current path is {save_path}" self.log.print_and_log(msg) - warnings.warn(msg) + logger.warning(msg) else: if self.results_path is None: self.results_path = save_path diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index eed04a57..dd346cfd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -32,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.workers import ( +from napari_cellseg3d.code_models.model_workers import ( TrainingReport, TrainingWorker, ) @@ -419,7 +419,9 @@ def check_ready(self): * False and displays a warning if not """ - if self.images_filepaths == [] and self.labels_filepaths != []: + if self.images_filepaths != [] and self.labels_filepaths != []: + return True + else: logger.warning("Image and label paths are not correctly set") return False return True diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 23bcb4c5..235595e4 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -16,7 +16,6 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager @@ -179,10 +178,11 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: - logger.warning( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None: + if cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 127ad0d7..18e06fa3 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -5,17 +5,17 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils -from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils -from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils +from napari_cellseg3d.code_plugins.plugin_convert import ( + AnisoUtils, + RemoveSmallUtils, + ThresholdUtils, + ToInstanceUtils, + ToSemanticUtils, +) from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index c7e6c6ee..b4712aec 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -4,8 +4,7 @@ import numpy as np import scipy.ndimage as ndimage from skimage.filters import threshold_otsu -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from napari_cellseg3d.code_models.model_instance_seg import binary_watershed diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2f079d09..2ab60332 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -8,8 +8,7 @@ import numpy as np import scipy.ndimage as ndimage from napari.qt.threading import thread_worker -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index dd3ff4e5..b8dbc442 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -2,7 +2,6 @@ import napari from napari.qt.threading import thread_worker -from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -13,8 +12,6 @@ QWidget, ) -rand_gen = Generator(PCG64(12345)) - @thread_worker def two_way_communication_with_args(start, end): diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d25f02d3..2e6c4e78 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -10,32 +10,30 @@ from qtpy import QtCore # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -317,22 +315,6 @@ def error(self, error, msg=None): finally: self.lock.release() - def error(self, error, msg=None): - """Show exception and message from another thread""" - self.lock.acquire() - try: - logger.error(error, exc_info=True) - if msg is not None: - self.print_and_log(f"{msg} : {error}", printing=False) - else: - self.print_and_log( - f"Excepetion caught in another thread : {error}", - printing=False, - ) - raise error - finally: - self.lock.release() - ############## # UI elements @@ -828,7 +810,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: logger.warning("Please select a valid layer !") - return None + return return self.layer().data @@ -1068,7 +1050,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min_value, max_value, default, step, parent, fixed) + box = class_(min, max, default, step, parent, fixed) boxes.append(box) return boxes @@ -1226,7 +1208,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, possible_paths: list = (), - file_extension: str = "Image file (*.tif *.tiff)", + filetype: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index e4330bad..7993247c 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -308,14 +308,14 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # LOGGER.warning( + # logger.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) pad = 2**n n += 1 - if pad >= 1024: + if pad >= 256: LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." diff --git a/pyproject.toml b/pyproject.toml index d9a46ccf..8e7187f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,12 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). +select = [ + "E", "F", "W", + "I", + "B", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) ignore = ["E501", "E741"] [tool.black] From ebdfd7c593ce483e1ca15d4449fe61efb4f5c2b4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:29:42 +0200 Subject: [PATCH 519/577] Functional CRF --- napari_cellseg3d/_tests/test_models.py | 39 +++++ napari_cellseg3d/code_models/crf.py | 98 +++++++++++- .../code_models/instance_segmentation.py | 6 +- napari_cellseg3d/code_models/workers.py | 28 ++-- .../code_plugins/plugin_convert.py | 146 +++++------------- napari_cellseg3d/code_plugins/plugin_crf.py | 34 +--- napari_cellseg3d/code_plugins/plugin_crop.py | 7 +- .../code_plugins/plugin_model_inference.py | 18 ++- .../code_plugins/plugin_utilities.py | 1 + napari_cellseg3d/utils.py | 12 +- 10 files changed, 227 insertions(+), 162 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 9280b230..1fc15872 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,9 +1,18 @@ +import numpy as np import torch +from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +def test_correct_shape_for_crf(): + test = np.random.rand(1, 1, 8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + test = np.random.rand(8, 8, 8) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + + def test_model_list(): for model_name in MODEL_LIST.keys(): dims = 128 @@ -31,3 +40,33 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) # assert res > 0 + + +def test_crf(qtbot): + dims = 8 + mock_image = np.random.rand(1, dims, dims, dims) + mock_label = np.random.rand(2, dims, dims, dims) + + crf = CRFWorker(mock_image, mock_label) + + def on_yield(result): + assert isinstance(result, np.ndarray) + assert result.shape[-3:] == mock_label.shape[-3:] + + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) + + mock_image = mock_image[0] + mock_label = mock_label[0] + + crf = CRFWorker(mock_image, mock_label) + crf.yielded.connect(on_yield) + crf.start() + with qtbot.waitSignal( + signal=crf.finished, timeout=60000, raising=False + ) as blocker: + blocker.connect(crf.errored) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index fc1e0b90..a0146a5e 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -9,11 +9,8 @@ Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. """ - from warnings import warn -import numpy as np - try: import pydensecrf.densecrf as dcrf from pydensecrf.utils import ( @@ -31,6 +28,12 @@ ) CRF_INSTALLED = False + +import numpy as np +from napari.qt.threading import GeneratorWorker + +from napari_cellseg3d.config import CRFConfig + __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ "Yves Paychère", @@ -49,6 +52,16 @@ ] +def correct_shape_for_crf(image): + if len(image.shape) == 4: + return image + if len(image.shape) > 4: + image = np.squeeze(image, axis=0) + if len(image.shape) < 4: + image = np.expand_dims(image, axis=0) + return correct_shape_for_crf(image) + + def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): """CRF post-processing step for the W-Net, applied to a batch of images. @@ -62,6 +75,8 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): Returns: np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. """ + if not CRF_INSTALLED: + return None return np.stack( [ @@ -83,10 +98,16 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + w1 (float): weight of the appearance/bilateral kernel. + w2 (float): weight of the smoothness/gaussian kernel. Returns: np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. """ + + if not CRF_INSTALLED: + return None + d = dcrf.DenseCRF( image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] ) @@ -123,3 +144,74 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): return np.array(Q).reshape( (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) ) + + +def crf_with_config(image, prob, config: CRFConfig = None): + if config is None: + config = CRFConfig() + if image.shape[-3:] != prob.shape[-3:]: + raise ValueError( + f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" + f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" + ) + + image = correct_shape_for_crf(image) + + return crf( + image, + prob, + config.sa, + config.sb, + config.sg, + config.w1, + config.w2, + config.n_iters, + ) + + +class CRFWorker(GeneratorWorker): + """Worker for the CRF post-processing step for the W-Net.""" + + def __init__( + self, + images_list, + labels_list, + config: CRFConfig = None, + log=None, + ): + super().__init__(self._run_crf_job) + + self.images = images_list + self.labels = labels_list + if config is None: + self.config = CRFConfig() + else: + self.config = config + self.log = log + + # TODO(cyril) : add progress bar into log ? or do it in inference + def _run_crf_job(self): + """Runs the CRF post-processing step for the W-Net.""" + if not CRF_INSTALLED: + raise ImportError("pydensecrf is not installed.") + + for image, labels in zip(self.images, self.labels): + if len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + + if len(labels.shape) == 3: + labels = np.expand_dims(labels, axis=0) + + if image.shape[-3:] != labels.shape[-3:]: + raise ValueError("Image and labels must have the same shape.") + + yield crf( + image, + labels, + self.config.sa, + self.config.sb, + self.config.sg, + self.config.w1, + self.config.w2, + n_iter=self.config.n_iters, + ) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index b4177ec0..1bbaf659 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -65,7 +65,7 @@ def __init__( 1, divide_factor=100, text_label="", - parent=None, + parent=widget_parent, ), ) self.sliders.append(getattr(self, widget)) @@ -76,7 +76,9 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(text_label="", parent=None), + ui.DoubleIncrementCounter( + text_label="", parent=widget_parent + ), ) self.counters.append(getattr(self, widget)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 5f34f9c3..1edd9976 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -52,6 +52,7 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import crf_with_config from napari_cellseg3d.code_models.model_instance_seg import ( ImageStats, volume_stats, @@ -250,6 +251,7 @@ class InferenceResult: image_id: int = 0 original: np.array = None instance_labels: np.array = None + crf_results: np.array = None stats: "np.array[ImageStats]" = None result: np.array = None model_name: str = None @@ -592,9 +594,12 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - total_dims = len(semantic_labels.shape) - 3 + extra_dims = len(semantic_labels.shape) - 3 semantic_labels = np.swapaxes( - semantic_labels, 0 + total_dims, 2 + total_dims + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims ) return InferenceResult( @@ -650,8 +655,7 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"{additional_info}" - + f"Prediction_{i+1}" + + f"{additional_info}_Prediction_{i+1}" + original_filename + self.config.model_info.name + f"_{time}_" @@ -758,15 +762,15 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) def run_crf(self, image, labels, image_id=0): + self.log(f"IMAGE SHAPE : {image.shape}") + self.log(f"LABEL SHAPE : {labels.shape}") + try: crf_results = crf_with_config( - image, labels, config=self.config.crf_config, log=self.log + image, labels, config=self.config.crf_config ) self.save_image( - crf_results, - i=image_id, - additional_info="CRF_", - from_layer=True, + crf_results, i=image_id, additional_info="CRF", from_layer=True ) return crf_results except ValueError as e: @@ -806,9 +810,15 @@ def inference_on_layer(self, image, model, post_process_transforms): instance_labels_results.append(instance_labels) stats_results.append(stats) + if self.config.use_crf: + crf_results = self.run_crf(image, out) + else: + crf_results = None + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels_results, + crf_results=crf_results, from_layer=True, stats=stats_results, ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ce4cecda..3fa21508 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -3,7 +3,7 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread import napari_cellseg3d.interface as ui from napari_cellseg3d import utils @@ -21,71 +21,6 @@ logger = utils.LOGGER -def save_folder(results_path, folder_name, images, image_paths): - """ - Saves a list of images in a folder - - Args: - results_path: Path to the folder containing results - folder_name: Name of the folder containing results - images: List of images to save - image_paths: list of filenames of images - """ - results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False, parents=True) - - for file, image in zip(image_paths, images): - path = results_folder / Path(file).name - - imwrite( - path, - image, - ) - logger.info(f"Saved processed folder as : {results_folder}") - - -def save_layer(results_path, image_name, image): - """ - Saves an image layer at the specified path - - Args: - results_path: path to folder containing result - image_name: image name for saving - image: data array containing image - - Returns: - - """ - path = str(results_path / Path(image_name)) # TODO flexible filetype - logger.info(f"Saved as : {path}") - imwrite(path, image) - - -def show_result(viewer, layer, image, name): - """ - Adds layers to a viewer to show result to user - - Args: - viewer: viewer to add layer in - layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer - image: the data array containing the image - name: name of the added layer - - Returns: - - """ - if isinstance(layer, napari.layers.Image): - logger.debug("Added resulting image layer") - viewer.add_image(image, name=name) - elif isinstance(layer, napari.layers.Labels): - logger.debug("Added resulting label layer") - viewer.add_labels(image, name=name) - else: - logger.warning( - f"Results not shown, unsupported layer type {type(layer)}" - ) - - class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -169,7 +104,7 @@ def _start(self): utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"isotropic_results_{utils.get_date_time()}", images, @@ -258,19 +193,18 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -346,7 +280,7 @@ def _start(self): to_semantic(file, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"semantic_results_{utils.get_date_time()}", images, @@ -426,19 +360,18 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.instance_widgets.run_method_on_channels(imread(file)) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -521,19 +454,18 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif ( - self.folder_choice.isChecked() and len(self.images_filepaths) != 0 - ): - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 76194e87..3dbd47bb 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -1,4 +1,3 @@ -import contextlib from functools import partial from pathlib import Path @@ -8,11 +7,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import ( - CRF_INSTALLED, - CRFWorker, - crf_with_config, -) +from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.utils import LOGGER as logger @@ -48,17 +43,6 @@ def __init__(self, parent=None): self._set_tooltips() def _build(self): - if not CRF_INSTALLED: - ui.add_widgets( - self.layout, - [ - ui.make_label( - "ERROR: CRF not installed.\nPlease refer to the documentation to install it." - ), - ], - ) - self.set_layout() - return ui.add_widgets( self.layout, [ @@ -129,10 +113,7 @@ def __init__(self, viewer, parent=None): napari.layers.Image ) # to load all crf-compatible inputs, not int only self.image_layer_loader.setVisible(True) - if CRF_INSTALLED: - self.start_button.setVisible(True) - else: - self.start_button.setVisible(False) + self.start_button.setVisible(True) self.result_layer = None self.result_name = None @@ -179,11 +160,6 @@ def _build(self): def make_config(self): return self.crf_params_widget.make_config() - def print_config(self): - logger.info("CRF config:") - for item in self.make_config().__dict__.items(): - logger.info(f"{item[0]}: {item[1]}") - def _check_ready(self): if len(self.label_layer_loader.layer_list) < 1: logger.warning("No label layer loaded") @@ -239,7 +215,7 @@ def _start(self): self.result_layer = self.label_layer_loader.layer() self.result_name = self.label_layer_loader.layer_name() - utils.mkdir_from_str(self.results_path) + self.results_path.mkdir(exist_ok=True, parents=True) image_list = [self.image_layer_loader.layer_data()] labels_list = [self.label_layer_loader.layer_data()] @@ -278,10 +254,6 @@ def _on_start(self): def _on_finish(self): self.worker = None - with contextlib.suppress(RuntimeError): - self.start_button.setText("Start") - - # should only happen when testing def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 957936ce..0c2e4042 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -193,7 +193,12 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + ui.ScrollArea.make_scrollable( + layout, + self, + max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], + min_wh=[200, 200], + ) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 218dbeaa..3927df5c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -19,6 +19,7 @@ InferenceResult, InferenceWorker, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget logger = utils.LOGGER @@ -201,6 +202,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) + self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", @@ -470,6 +472,8 @@ def _build(self): self.crf_widgets, self.use_instance_choice, self.instance_widgets, + self.use_crf, + self.crf_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -635,6 +639,8 @@ def start(self): compute_stats=self.save_stats_to_csv_box.isChecked(), post_process_config=self.post_process_config, sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), ) ##################### ##################### @@ -870,7 +876,10 @@ def on_yield(self, result: InferenceResult): self.worker_config.post_process_config.instance.method.name ) - if result.instance_labels is not None: + if ( + len(result.instance_labels) > 0 + and self.worker_config.post_process_config.instance.enabled + ): for i, labels in enumerate(result.instance_labels): # labels = result.instance_labels method_name = ( @@ -912,5 +921,12 @@ def on_yield(self, result: InferenceResult): # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) + + if result.crf_results is not None: + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 18e06fa3..6e1a606a 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -16,6 +16,7 @@ ToInstanceUtils, ToSemanticUtils, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7993247c..dfceadfb 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -2,15 +2,13 @@ from datetime import datetime from pathlib import Path +import napari import numpy as np from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian from tifffile import imread, imwrite -if TYPE_CHECKING: - import torch - LOGGER = logging.getLogger(__name__) ############### # Global logging level setting @@ -117,7 +115,7 @@ def __call__(cls, *args, **kwargs): # if filename == "tif": # return True # def read(self, data, **kwargs): -# return tfl_imread(data) +# return imread(data) # # def get_data(self, data): # return data, {} @@ -308,7 +306,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # logger.warning( + # LOGGER.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) @@ -545,9 +543,7 @@ def load_images( ) # images_original = dask_imread(filename_pattern_original) else: - images_original = tfl_imread( - filename_pattern_original - ) # tifffile imread + images_original = imread(filename_pattern_original) # tifffile imread return imread(filename_pattern_original) # tifffile imread From 106f4b7c9a0f347ef6d5aa5cc861766e255d6e3d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:37:33 +0200 Subject: [PATCH 520/577] Fix erroneous test comment, added toggle for crf - Warn if crf not installed - Fix test --- napari_cellseg3d/_tests/test_utils.py | 2 +- napari_cellseg3d/code_plugins/plugin_crf.py | 22 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 05b84b08..48550747 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -89,7 +89,7 @@ def test_get_padding_dim(): # "The padding value is currently 2048." # ) # - # pad = utils.get_padding_dim(size) + pad = utils.get_padding_dim(size) # # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 3dbd47bb..cbdacf3a 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -7,7 +7,11 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.crf import CRFWorker, crf_with_config +from napari_cellseg3d.code_models.crf import ( + CRF_INSTALLED, + CRFWorker, + crf_with_config, +) from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.utils import LOGGER as logger @@ -43,6 +47,17 @@ def __init__(self, parent=None): self._set_tooltips() def _build(self): + if not CRF_INSTALLED: + ui.add_widgets( + self.layout, + [ + ui.make_label( + "ERROR: CRF not installed.\nPlease refer to the documentation to install it." + ), + ], + ) + self.set_layout() + return ui.add_widgets( self.layout, [ @@ -113,7 +128,10 @@ def __init__(self, viewer, parent=None): napari.layers.Image ) # to load all crf-compatible inputs, not int only self.image_layer_loader.setVisible(True) - self.start_button.setVisible(True) + if CRF_INSTALLED: + self.start_button.setVisible(True) + else: + self.start_button.setVisible(False) self.result_layer = None self.result_name = None From c5bd372765b79cd17d14c0f6d4fdf8350d170879 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 17:56:08 +0200 Subject: [PATCH 521/577] Specify missing test deps --- pyproject.toml | 3 ++- tox.ini | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e7187f5..5648ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ -"git+https://github.com/lucasb-eyer/pydensecrf.git", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,4 +81,5 @@ test = [ "coverage", "tox", "twine", + "git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 4b04a5bc..c6e863e1 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = magicgui pytest-qt qtpy + "git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] ; opencv-python extras = crf From ae5810b3482c4e682e8514c0f840f856dfc566a2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:02:31 +0200 Subject: [PATCH 522/577] Trying to fix deps on Git --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5648ab40..73fc862c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", ] diff --git a/tox.ini b/tox.ini index c6e863e1..037e385e 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - "git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" ; pyopencl[pocl] ; opencv-python extras = crf From 1f7dae96fc2f7855ebd0d4a20f119e797b2a3ff9 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:04:33 +0200 Subject: [PATCH 523/577] Removed master link to pydensecrf --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73fc862c..8d9d6bf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", ] diff --git a/tox.ini b/tox.ini index 037e385e..bcd3a2c0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" ; pyopencl[pocl] ; opencv-python extras = crf From fd1c6c62894bb8e33b40f4decade78fc027c0ae1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:07:23 +0200 Subject: [PATCH 524/577] Use commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d9d6bf4..0cc237e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", ] diff --git a/tox.ini b/tox.ini index bcd3a2c0..b496b9b0 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" ; pyopencl[pocl] ; opencv-python extras = crf From 4401b8973193715ef769884d77b98811e5431925 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:09:27 +0200 Subject: [PATCH 525/577] Removed commit hash --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0cc237e5..09ed8585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", ] diff --git a/tox.ini b/tox.ini index b496b9b0..024c8955 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@0d53acb" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" ; pyopencl[pocl] ; opencv-python extras = crf From c0f851a4e76aa7674bc4aecdf8d8128174aa187d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:11:27 +0200 Subject: [PATCH 526/577] Removed master --- pyproject.toml | 4 ++-- tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 09ed8585..db39904b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master", + "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", ] diff --git a/tox.ini b/tox.ini index 024c8955..64350cc6 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf@master" + pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" ; pyopencl[pocl] ; opencv-python extras = crf From 781cd0dc52049261777e3d36a114754e529756ce Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 26 Apr 2023 18:17:16 +0200 Subject: [PATCH 527/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 64350cc6..bb02bb56 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf" + pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python extras = crf From a3c9dbfda744190c7dfdfc60a60c40e6abc68734 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 09:06:23 +0200 Subject: [PATCH 528/577] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db39904b..d223072a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ line_length = 79 [project.optional-dependencies] crf = [ - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] dev = [ "isort", @@ -81,5 +81,5 @@ test = [ "coverage", "tox", "twine", - "pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] From 8c93b6534f9f95c1d8f9fcf22b6ea40d01a54ed3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 28 Apr 2023 17:41:05 +0200 Subject: [PATCH 529/577] Fixes and improvements - More CRF info - Added error handling to scheduler rate - Added ETA to training - Updated padding warning trigger size --- napari_cellseg3d/code_models/crf.py | 30 ++++++++++------ .../code_models/models/model_VNet.py | 2 +- napari_cellseg3d/code_models/workers.py | 36 ++++++++++++------- napari_cellseg3d/code_plugins/plugin_crf.py | 6 ++++ .../code_plugins/plugin_model_inference.py | 3 ++ .../code_plugins/plugin_model_training.py | 2 +- napari_cellseg3d/utils.py | 2 +- 7 files changed, 54 insertions(+), 27 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index a0146a5e..1b8dce28 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -33,6 +33,7 @@ from napari.qt.threading import GeneratorWorker from napari_cellseg3d.config import CRFConfig +from napari_cellseg3d.utils import LOGGER as logger __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ @@ -52,12 +53,16 @@ ] -def correct_shape_for_crf(image): - if len(image.shape) == 4: +def correct_shape_for_crf(image, desired_dims=4): + if len(image.shape) == desired_dims: return image - if len(image.shape) > 4: + if len(image.shape) > desired_dims: + if image.shape[0] > 1: + raise ValueError( + f"Image shape {image.shape} might have several channels" + ) image = np.squeeze(image, axis=0) - if len(image.shape) < 4: + if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) return correct_shape_for_crf(image) @@ -146,7 +151,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): ) -def crf_with_config(image, prob, config: CRFConfig = None): +def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): if config is None: config = CRFConfig() if image.shape[-3:] != prob.shape[-3:]: @@ -156,6 +161,12 @@ def crf_with_config(image, prob, config: CRFConfig = None): ) image = correct_shape_for_crf(image) + prob = correct_shape_for_crf(prob) + + if log is not None: + log("Running CRF post-processing step") + log(f"Image shape : {image.shape}") + log(f"Labels shape : {prob.shape}") return crf( image, @@ -196,15 +207,12 @@ def _run_crf_job(self): raise ImportError("pydensecrf is not installed.") for image, labels in zip(self.images, self.labels): - if len(image.shape) == 3: - image = np.expand_dims(image, axis=0) - - if len(labels.shape) == 3: - labels = np.expand_dims(labels, axis=0) - if image.shape[-3:] != labels.shape[-3:]: raise ValueError("Image and labels must have the same shape.") + image = correct_shape_for_crf(image) + labels = correct_shape_for_crf(labels) + yield crf( image, labels, diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 41554e80..7aa6476e 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=1, **kwargs): + def __init__(self, in_channels=1, out_channels=2, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 1edd9976..fe686134 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1,4 +1,5 @@ import platform +import time import typing as t from dataclasses import dataclass from math import ceil @@ -648,7 +649,7 @@ def save_image( filetype = self.config.filetype else: original_filename = "_" - filetype = "" + filetype = ".tif" time = utils.get_date_time() @@ -762,12 +763,9 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) def run_crf(self, image, labels, image_id=0): - self.log(f"IMAGE SHAPE : {image.shape}") - self.log(f"LABEL SHAPE : {labels.shape}") - try: crf_results = crf_with_config( - image, labels, config=self.config.crf_config + image, labels, config=self.config.crf_config, log=self.log ) self.save_image( crf_results, i=image_id, additional_info="CRF", from_layer=True @@ -1394,14 +1392,23 @@ def train(self): optimizer = torch.optim.Adam( model.parameters(), self.config.learning_rate ) + + factor = self.config.scheduler_factor + if factor >= 1.0: + self.log(f"Warning : scheduler factor is {factor} >= 1.0") + self.log("Setting it to 0.5") + factor = 0.5 + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer=optimizer, mode="min", - factor=self.config.scheduler_factor, + factor=factor, patience=self.config.scheduler_patience, verbose=VERBOSE_SCHEDULER, ) - dice_metric = DiceMetric(include_background=True, reduction="mean") + dice_metric = DiceMetric( + include_background=False, reduction="mean" + ) best_metric = -1 best_metric_epoch = -1 @@ -1503,12 +1510,15 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] - eta = ( - (time.time() - start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 + self.log( + "ETA: " + + str( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 + ) + + "minutes" ) - self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 @@ -1533,7 +1543,7 @@ def train(self): overlap=0.25, sw_device=self.config.device, device=self.config.device, - progress=True, + progress=False, ) except Exception as e: self.raise_error(e, "Error during validation") diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index cbdacf3a..7ac605e9 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -178,6 +178,11 @@ def _build(self): def make_config(self): return self.crf_params_widget.make_config() + def print_config(self): + logger.info("CRF config:") + for item in self.make_config().__dict__.items(): + logger.info(f"{item[0]}: {item[1]}") + def _check_ready(self): if len(self.label_layer_loader.layer_list) < 1: logger.warning("No label layer loaded") @@ -272,6 +277,7 @@ def _on_start(self): def _on_finish(self): self.worker = None + self.start_button.setText("Start") def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 3927df5c..69c518ce 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -923,6 +923,9 @@ def on_yield(self, result: InferenceResult): # ) if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) viewer.add_image( result.crf_results, name=f"CRF_results_image_{image_id}", diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index dd346cfd..0670c1c7 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -852,7 +852,7 @@ def start(self): loss_function=self.get_loss(self.loss_choice.currentText()), learning_rate=float(self.learning_rate_choice.currentText()), scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.value(), + scheduler_factor=self.scheduler_factor_choice.slider_value, validation_interval=self.val_interval_choice.value(), batch_size=self.batch_choice.slider_value, results_path_folder=str(results_path_folder), diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index dfceadfb..a77e7cbd 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -313,7 +313,7 @@ def get_padding_dim(image_shape, anisotropy_factor=None): pad = 2**n n += 1 - if pad >= 256: + if pad >= 1024: LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory." From de939c097fb0b7e69e8d77007ea68e35af088396 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 3 May 2023 09:57:34 +0200 Subject: [PATCH 530/577] Fixes and channel labeling prototype --- napari_cellseg3d/code_models/workers.py | 27 ++-- .../extract_extra_channels_labels.py | 124 ++++++++++++++++++ 2 files changed, 137 insertions(+), 14 deletions(-) create mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index fe686134..35253f59 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -596,12 +596,14 @@ def create_inference_result( "A layer's ID should always be 0 (default value)" ) extra_dims = len(semantic_labels.shape) - 3 - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + if semantic_labels is not None: + semantic_labels = np.swapaxes( + semantic_labels, 0 + extra_dims, 2 + extra_dims + ) + if crf_results is not None: + crf_results = np.swapaxes( + crf_results, 0 + extra_dims, 2 + extra_dims + ) return InferenceResult( image_id=i + 1, @@ -1510,15 +1512,12 @@ def train(self): scheduler.step(epoch_loss) checkpoint_output = [] - self.log( - "ETA: " - + str( - (time.time() - start_time) - * (self.config.max_epochs / (epoch + 1) - 1) - / 60 - ) - + "minutes" + eta = ( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 ) + self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py new file mode 100644 index 00000000..2bd0a536 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -0,0 +1,124 @@ +import numpy as np +from skimage.filters import threshold_otsu +from skimage.segmentation import expand_labels +from tqdm import tqdm + + +def extract_labels_from_channels( + nucleus_labels: np.array, + extra_channels: list, + radius: int = 4, + threshold_factor=2, + viewer=None, +): + """ + Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. + Args: + nucleus_labels (np.array): labels for the nuclei + extra_channels (list): channels arrays to extract labels from + radius: radius in which the approximation is made + + Returns: + A list of extracted labels for each extra channel + """ + labeled_channels = {} + + contrasted_channels = [] + for channel in extra_channels: + channel = (channel - np.min(channel)) / ( + np.max(channel) - np.min(channel) + ) + threshold_brightness = threshold_otsu(channel) * threshold_factor + channel_contrasted = np.where( + channel > threshold_brightness, channel, 0 + ) + contrasted_channels.append(channel_contrasted) + if viewer is not None: + viewer.add_image( + channel_contrasted, + name="channel_contrasted", + colormap="viridis", + ) + for label_id in tqdm(np.unique(nucleus_labels)): + if label_id == 0: + continue + label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + expanded = expand_labels(label_nucleus, distance=radius) + for i, channel in enumerate(contrasted_channels): + label_contrasted = np.where(expanded != 0, channel, 0) + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels[ + f"label_{label_id}_channel_{i+1}" + ] = np.count_nonzero(labeled_channel) + if np.count_nonzero(labeled_channel) > 0 and viewer is not None: + print(np.count_nonzero(labeled_channel)) + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) + + return labeled_channels + + +if __name__ == "__main__": + from pathlib import Path + + import napari + import pandas as pd + from tifffile import imread + + image_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + ) + # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" + nuclei_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" + ) + extra_channels_path = ( + Path.home() + / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" + ) + extra_channels = [ + imread(str(path)) + for path in extra_channels_path.glob( + "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" + ) + ] + labels = imread(str(image_path)) + viewer = napari.Viewer() + + shift = 0 + viewer.add_image( + imread(str(nuclei_path))[ + shift : 32 + shift, shift : 32 + shift, shift : 32 + shift + ], + name="nuclei", + ) + viewer.add_labels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + [ + viewer.add_image( + channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + ) + for channel in extra_channels + ] + + labeled_channels = extract_labels_from_channels( + labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], + [ + c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] + for c in extra_channels + ], + radius=4, + viewer=viewer, + ) + table = pd.DataFrame( + labeled_channels.items(), columns=["name", "pixels count"] + ) + print(table) + # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] + # expanded = expand_labels(labels, 4) + # viewer.add_labels(expanded) + napari.run() From 33a6da85e944c0f70e7d3fb9f2fb6e53498861e7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 09:18:42 +0200 Subject: [PATCH 531/577] Fixes - Fixed multi-channel instance and csv stats - Fixed rotation of inference outputs - Raised max crop size --- napari_cellseg3d/code_models/workers.py | 69 +++++++++--------- .../code_plugins/plugin_model_inference.py | 72 +++++++------------ .../extract_extra_channels_labels.py | 64 +++++++++++------ napari_cellseg3d/utils.py | 6 ++ 4 files changed, 110 insertions(+), 101 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 35253f59..4a8bb913 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -595,15 +595,15 @@ def create_inference_result( raise ValueError( "A layer's ID should always be 0 (default value)" ) - extra_dims = len(semantic_labels.shape) - 3 + if semantic_labels is not None: - semantic_labels = np.swapaxes( - semantic_labels, 0 + extra_dims, 2 + extra_dims - ) + semantic_labels = utils.correct_rotation(semantic_labels) if crf_results is not None: - crf_results = np.swapaxes( - crf_results, 0 + extra_dims, 2 + extra_dims - ) + crf_results = utils.correct_rotation(crf_results) + if instance_labels is not None: + instance_labels = utils.correct_rotation( + instance_labels + ) # TODO(cyril) check if correct return InferenceResult( image_id=i + 1, @@ -629,10 +629,6 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) - if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -658,10 +654,11 @@ def save_image( file_path = ( self.config.results_path + "/" - + f"{additional_info}_Prediction_{i+1}" + + f"{additional_info}" + + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}_" + + f"_{time}" + filetype ) try: @@ -688,18 +685,20 @@ def aniso_transform(self, image): return image def instance_seg( - self, to_instance, image_id=0, original_filename="layer", channel=None + self, semantic_labels, image_id=0, original_filename="layer" ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) - if channel is not None: - channel_id = f"_{channel}" + if len(semantic_labels.shape) == 4: + instance_labels = np.array( + [method.run_method(ch) for ch in semantic_labels] + ) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") else: - channel_id = "" + instance_labels = method.run_method(image=semantic_labels) if self.config.filetype == "": filetype = "" @@ -711,7 +710,6 @@ def instance_seg( + "/" + f"Instance_seg_labels_{image_id}_" + original_filename - + channel_id + "_" + self.config.model_info.name + f"_{utils.get_date_time()}" @@ -770,7 +768,10 @@ def run_crf(self, image, labels, image_id=0): image, labels, config=self.config.crf_config, log=self.log ) self.save_image( - crf_results, i=image_id, additional_info="CRF", from_layer=True + crf_results, + i=image_id, + additional_info="CRF_", + from_layer=True, ) return crf_results except ValueError as e: @@ -778,9 +779,15 @@ def run_crf(self, image, labels, image_id=0): return None def stats_csv(self, instance_labels): - if self.config.compute_stats: - stats = volume_stats(instance_labels) - return stats + try: + if self.config.compute_stats: + if len(instance_labels.shape) == 4: + stats = [volume_stats(c) for c in instance_labels] + else: + stats = [volume_stats(instance_labels)] + return stats + else: + return None except ValueError as e: self.log(f"Error occurred during stats computing : {e}") return None @@ -800,15 +807,9 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels_results = [] - stats_results = [] - - for channel in out: - instance_labels, stats = self.get_instance_result( - channel, from_layer=True - ) - instance_labels_results.append(instance_labels) - stats_results.append(stats) + instance_labels, stats = self.get_instance_result( + semantic_labels=out, from_layer=True + ) if self.config.use_crf: crf_results = self.run_crf(image, out) @@ -817,10 +818,10 @@ def inference_on_layer(self, image, model, post_process_transforms): return self.create_inference_result( semantic_labels=out, - instance_labels=instance_labels_results, + instance_labels=instance_labels, crf_results=crf_results, from_layer=True, - stats=stats_results, + stats=stats, ) # @thread_worker(connect={"errored": self.raise_error}) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 69c518ce..b63de2b8 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -472,8 +472,6 @@ def _build(self): self.crf_widgets, self.use_instance_choice, self.instance_widgets, - self.use_crf, - self.crf_widgets, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -868,68 +866,52 @@ def on_yield(self, result: InferenceResult): name=f"CRF_results_image_{image_id}", colormap="viridis", ) + if ( - result.instance_labels is not None + len(result.instance_labels) > 0 and self.worker_config.post_process_config.instance.enabled ): method_name = ( self.worker_config.post_process_config.instance.method.name ) - if ( - len(result.instance_labels) > 0 - and self.worker_config.post_process_config.instance.enabled - ): - for i, labels in enumerate(result.instance_labels): - # labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) - - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + number_cells = ( + np.unique(result.instance_labels.flatten()).size - 1 + ) # remove background - name = f"({number_cells} objects)_{method_name}_channel_{i}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - viewer.add_labels(labels, name=name) + viewer.add_labels(result.instance_labels, name=name) from napari_cellseg3d.utils import LOGGER as log - log.debug(f"len stats : {len(result.stats)}") + if result.stats is not None and isinstance( + result.stats, list + ): + log.debug(f"len stats : {len(result.stats)}") - for i, stats in enumerate(result.stats): - # stats = result.stats + for i, stats in enumerate(result.stats): + # stats = result.stats - if ( - self.worker_config.compute_stats - and stats is not None - ): - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances in channel {i} : {stats.number_objects[0]}" - ) + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) # self.log.print_and_log( # f"OBJECTS DETECTED : {number_cells}\n" # ) - - if result.crf_results is not None: - logger.debug( - f"CRF results shape : {result.crf_results.shape}" - ) - viewer.add_image( - result.crf_results, - name=f"CRF_results_image_{image_id}", - colormap="viridis", - ) except Exception as e: self.on_error(e) diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py index 2bd0a536..70ee10b6 100644 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py @@ -4,8 +4,8 @@ from tqdm import tqdm -def extract_labels_from_channels( - nucleus_labels: np.array, +def extract_labels_from_channels( # TODO add separate channels results + nuclei_labels: np.array, extra_channels: list, radius: int = 4, threshold_factor=2, @@ -14,15 +14,14 @@ def extract_labels_from_channels( """ Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. Args: - nucleus_labels (np.array): labels for the nuclei + nuclei_labels (np.array): labels for the nuclei extra_channels (list): channels arrays to extract labels from radius: radius in which the approximation is made Returns: A list of extracted labels for each extra channel """ - labeled_channels = {} - + labeled_channels = [] contrasted_channels = [] for channel in extra_channels: channel = (channel - np.min(channel)) / ( @@ -39,31 +38,54 @@ def extract_labels_from_channels( name="channel_contrasted", colormap="viridis", ) - for label_id in tqdm(np.unique(nucleus_labels)): + for label_id in tqdm(np.unique(nuclei_labels)): if label_id == 0: continue - label_nucleus = np.where(nucleus_labels == label_id, nucleus_labels, 0) + label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) expanded = expand_labels(label_nucleus, distance=radius) + restricted = np.where(expanded != 0, nuclei_labels, 0) + overlap = np.where(restricted != label_id, restricted, 0) + for i, channel in enumerate(contrasted_channels): label_contrasted = np.where(expanded != 0, channel, 0) - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels[ - f"label_{label_id}_channel_{i+1}" - ] = np.count_nonzero(labeled_channel) - if np.count_nonzero(labeled_channel) > 0 and viewer is not None: - print(np.count_nonzero(labeled_channel)) - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) + if overlap.any() != 0: + max_labeled = 0 + for overlap_id in np.unique(overlap): + if overlap_id == 0: + continue + assigned_pixels = np.count_nonzero( + np.where(overlap == overlap_id, channel, 0) + ) + if assigned_pixels > max_labeled: + max_labeled = assigned_pixels + max_label_id = overlap_id + if label_id != max_label_id: + labeled_channels.append( + np.zeros_like(label_contrasted) + ) + else: + labeled_channel = np.where(label_contrasted != 0, label_id, 0) + labeled_channels.append(labeled_channel) + if ( + np.count_nonzero(labeled_channel) > 0 + and viewer is not None + ): + viewer.add_labels( + labeled_channel, name=f"label_{label_id}_channel_{i+1}" + ) - return labeled_channels + cat_labels = np.zeros_like(nuclei_labels) + for labels in np.unique(labeled_channels): + if labels == 0: + continue + cat_labels += np.where(labels != 0, labels, 0) + return cat_labels if __name__ == "__main__": from pathlib import Path import napari - import pandas as pd from tifffile import imread image_path = ( @@ -114,10 +136,8 @@ def extract_labels_from_channels( radius=4, viewer=viewer, ) - table = pd.DataFrame( - labeled_channels.items(), columns=["name", "pixels count"] - ) - print(table) + + viewer.add_labels(labeled_channels) # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] # expanded = expand_labels(labels, 4) # viewer.add_labels(expanded) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index a77e7cbd..7e9eb625 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -203,6 +203,12 @@ def dice_coeff(y_true, y_pred): ) +def correct_rotation(image): + """Rotates the exes 0 and 2 in [DHW] section of image array""" + extra_dims = len(image) - 3 + return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, From 23261e2b291694e232eeac21fec36e4ac066d59d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 5 May 2023 14:42:02 +0200 Subject: [PATCH 532/577] Update plugin_model_inference.py --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index b63de2b8..0233ac23 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -866,9 +866,8 @@ def on_yield(self, result: InferenceResult): name=f"CRF_results_image_{image_id}", colormap="viridis", ) - if ( - len(result.instance_labels) > 0 + result.instance_labels is not None and self.worker_config.post_process_config.instance.enabled ): method_name = ( From 04a44d02246533db4ff99c9256dc614952264de7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 10:16:58 +0200 Subject: [PATCH 533/577] Fixed patch_func sample number mismatch --- napari_cellseg3d/code_models/workers.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 4a8bb913..1dc1a777 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1278,6 +1278,7 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") + if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? @@ -1309,6 +1310,31 @@ def train(self): ) # self.log("Loading dataset...\n") + def get_loader_func(num_samples): + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) + ), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + if do_sampling: # if there is only one volume, split samples # TODO(cyril) : maybe implement something in user config to toggle this behavior @@ -1331,6 +1357,7 @@ def train(self): sample_loader_train = get_loader_func(num_train_samples) sample_loader_eval = get_loader_func(num_val_samples) + logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" From 314ddd463c82340bb15b30d043eee1708ab050d3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:08:52 +0200 Subject: [PATCH 534/577] Testing relabel tools --- napari_cellseg3d/dev_scripts/correct_labels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 2ab60332..9862c3fa 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -367,8 +367,8 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): # if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") +# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") # -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +# image_path = str(im_path / "volumes/images.tif") +# gt_labels_path = str(im_path / "labels/testing_im.tif") +# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) From aa228a8b0ad7f4423a2e5a384a34faf36aa005db Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 11 May 2023 11:38:45 +0200 Subject: [PATCH 535/577] Fixes in inference --- napari_cellseg3d/code_models/workers.py | 2 ++ napari_cellseg3d/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 1dc1a777..f1ed632a 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -552,6 +552,8 @@ def model_output_wrapper(inputs): sw_device=self.config.device, device=dataset_device, overlap=window_overlap, + mode="gaussian", + sigma_scale=0.01, progress=True, ) except Exception as e: diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 7e9eb625..1f0c17ea 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -205,7 +205,7 @@ def dice_coeff(y_true, y_pred): def correct_rotation(image): """Rotates the exes 0 and 2 in [DHW] section of image array""" - extra_dims = len(image) - 3 + extra_dims = len(image.shape) - 3 return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) From db6ed07972fc00e812ad2f6ab622aae087f0677c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 14:48:14 +0200 Subject: [PATCH 536/577] add model template + fix test + wnet loading opti - test fixes - changed crf input reqs - adapted instance seg for several channels --- napari_cellseg3d/_tests/test_models.py | 10 ++- .../_tests/test_plugin_inference.py | 11 ++-- napari_cellseg3d/_tests/test_training.py | 11 ++-- napari_cellseg3d/code_models/crf.py | 11 ++-- .../code_models/instance_segmentation.py | 16 ++--- .../code_models/models/model_SwinUNetR.py | 13 +++- .../code_models/models/model_WNet.py | 19 ++++++ napari_cellseg3d/code_models/workers.py | 61 ++++++++----------- .../code_plugins/plugin_convert.py | 2 +- 9 files changed, 90 insertions(+), 64 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 1fc15872..35af8c76 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -15,6 +15,8 @@ def test_correct_shape_for_crf(): def test_model_list(): for model_name in MODEL_LIST.keys(): + # if model_name=="test": + # continue dims = 128 test = MODEL_LIST[model_name]( input_img_size=[dims, dims, dims], @@ -39,18 +41,20 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) - # assert res > 0 + assert 0 <= res <= 1 def test_crf(qtbot): dims = 8 mock_image = np.random.rand(1, dims, dims, dims) mock_label = np.random.rand(2, dims, dims, dims) - - crf = CRFWorker(mock_image, mock_label) + assert len(mock_label.shape) == 4 + crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) + assert len(result.shape) == 4 + assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] crf.yielded.connect(on_yield) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 66c50fba..3dafeabc 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,9 +3,10 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_inference(make_napari_viewer, qtbot): @@ -28,9 +29,9 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.setCurrentIndex(-1) # widget.start() # takes too long on Github Actions # assert widget.worker is not None diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 9adb3286..fbfe7bcc 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,9 +2,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture -from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer -from napari_cellseg3d.config import MODEL_LIST + +# from napari_cellseg3d.config import MODEL_LIST +# from napari_cellseg3d.code_models.models.model_test import TestModel def test_training(make_napari_viewer, qtbot): @@ -32,9 +33,9 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + # MODEL_LIST["test"] = TestModel() + # widget.model_choice.addItem("test") + # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) # widget.start() # assert widget.worker is not None diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 1b8dce28..21caf35f 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -57,10 +57,10 @@ def correct_shape_for_crf(image, desired_dims=4): if len(image.shape) == desired_dims: return image if len(image.shape) > desired_dims: - if image.shape[0] > 1: - raise ValueError( - f"Image shape {image.shape} might have several channels" - ) + # if image.shape[0] > 1: + # raise ValueError( + # f"Image shape {image.shape} might have several channels" + # ) image = np.squeeze(image, axis=0) if len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) @@ -200,7 +200,6 @@ def __init__( self.config = config self.log = log - # TODO(cyril) : add progress bar into log ? or do it in inference def _run_crf_job(self): """Runs the CRF post-processing step for the W-Net.""" if not CRF_INSTALLED: @@ -211,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + # labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 1bbaf659..5e759b0c 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from functools import partial from typing import List @@ -93,18 +94,19 @@ def _make_list_from_channels( raise ValueError( f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" ) - if len(image.shape) < 2: - raise ValueError( - f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" - ) if len(image.shape) == 4: image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - return [image] + elif len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) + else: + return [image] def run_method_on_channels(self, image): - image_list = self._make_list_from_channels(image) + image_list = self._make_list_from_channels(image) # FIXME rename result = np.array([self.run_method(im) for im in image_list]) return result.squeeze() @@ -607,7 +609,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method(volume) + return method.run_method_on_channels(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index c687eac2..112c59e9 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -10,12 +10,19 @@ class SwinUNETR_(SwinUNETR): use_default_training = True weights_file = "Swin64_best_metric.pth" - def __init__(self, input_img_size, use_checkpoint=True, **kwargs): + def __init__( + self, + in_channels=1, + out_channels=1, + input_img_size=128, + use_checkpoint=True, + **kwargs, + ): try: super().__init__( input_img_size, - in_channels=1, - out_channels=1, + in_channels=in_channels, + out_channels=out_channels, feature_size=48, use_checkpoint=use_checkpoint, **kwargs, diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 4a9ff70d..86a1f7e6 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,5 +1,12 @@ +from typing import TypeVar + +from torch.nn import Module + +# local from napari_cellseg3d.code_models.models.wnet.model import WNet +T = TypeVar("T", bound="Module") + class WNet_(WNet): use_default_training = False @@ -20,6 +27,9 @@ def __init__( num_classes=num_classes, ) + def train(self: T, mode: bool = True) -> T: + raise NotImplementedError("Training not implemented for WNet") + def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. @@ -27,3 +37,12 @@ def forward(self, x): enc = self.forward_encoder(x) # dec = self.forward_decoder(enc) return enc + + def load_state_dict(self, state_dict, strict=False): + """Load the model state dict for inference, without the decoder weights.""" + encoder_checkpoint = state_dict.copy() + for k in state_dict.keys(): + if k.startswith("decoder"): + encoder_checkpoint.pop(k) + # print(encoder_checkpoint.keys()) + super().load_state_dict(encoder_checkpoint, strict=strict) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index f1ed632a..b1ead29c 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -693,17 +693,11 @@ def instance_seg( self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - - if len(semantic_labels.shape) == 4: - instance_labels = np.array( - [method.run_method(ch) for ch in semantic_labels] - ) - self.log(f"DEBUG instance results shape : {instance_labels.shape}") - else: - instance_labels = method.run_method(image=semantic_labels) + instance_labels = method.run_method_on_channels(semantic_labels) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") if self.config.filetype == "": - filetype = "" + filetype = ".tif" else: filetype = "_" + self.config.filetype @@ -903,7 +897,8 @@ def inference(self): weights = str( PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) ) - model.load_state_dict( + + model.load_state_dict( # note that this is redefined in WNet_ torch.load( weights, map_location=self.config.device, @@ -1280,7 +1275,6 @@ def train(self): if len(self.val_files) == 0: raise ValueError("Validation dataset is empty") - if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? @@ -1313,29 +1307,29 @@ def train(self): # self.log("Loading dataset...\n") def get_loader_func(num_samples): - return Compose( - [ - LoadImaged(keys=["image", "label"]), - EnsureChannelFirstd(keys=["image", "label"]), - RandSpatialCropSamplesd( - keys=["image", "label"], - roi_size=( - self.config.sample_size - ), # multiply by axis_stretch_factor if anisotropy - # max_roi_size=(120, 120, 120), - random_size=False, - num_samples=num_samples, - ), - Orientationd(keys=["image", "label"], axcodes="PLI"), - SpatialPadd( - keys=["image", "label"], - spatial_size=( - utils.get_padding_dim(self.config.sample_size) - ), + return Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + self.config.sample_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=num_samples, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=( + utils.get_padding_dim(self.config.sample_size) ), - EnsureTyped(keys=["image", "label"]), - ] - ) + ), + EnsureTyped(keys=["image", "label"]), + ] + ) if do_sampling: # if there is only one volume, split samples @@ -1359,7 +1353,6 @@ def get_loader_func(num_samples): sample_loader_train = get_loader_func(num_train_samples) sample_loader_eval = get_loader_func(num_val_samples) - logger.debug(f"AMOUNT of train samples : {num_train_samples}") logger.debug( f"AMOUNT of validation samples : {num_val_samples}" diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 3fa21508..d9026912 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -363,7 +363,7 @@ def _start(self): elif self.folder_choice.isChecked(): if len(self.images_filepaths) != 0: images = [ - self.instance_widgets.run_method(imread(file)) + self.instance_widgets.run_method_on_channels(imread(file)) for file in self.images_filepaths ] utils.save_folder( From 277d2b57c7dcb0cae4a4b9ce9fd53e46b8ee841e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 May 2023 15:16:25 +0200 Subject: [PATCH 537/577] Update model_WNet.py --- napari_cellseg3d/code_models/models/model_WNet.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 86a1f7e6..f07ac517 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,12 +1,6 @@ -from typing import TypeVar - -from torch.nn import Module - # local from napari_cellseg3d.code_models.models.wnet.model import WNet -T = TypeVar("T", bound="Module") - class WNet_(WNet): use_default_training = False @@ -27,8 +21,8 @@ def __init__( num_classes=num_classes, ) - def train(self: T, mode: bool = True) -> T: - raise NotImplementedError("Training not implemented for WNet") + # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): """Forward ENCODER pass of the W-Net model. From 6269c9337f8974d6cd780134fffd8075edf2df93 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 13 May 2023 10:29:39 +0200 Subject: [PATCH 538/577] Update model_VNet.py --- napari_cellseg3d/code_models/models/model_VNet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 7aa6476e..41554e80 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -5,7 +5,7 @@ class VNet_(VNet): use_default_training = True weights_file = "VNet_40e.pth" - def __init__(self, in_channels=1, out_channels=2, **kwargs): + def __init__(self, in_channels=1, out_channels=1, **kwargs): try: super().__init__( in_channels=in_channels, out_channels=out_channels, **kwargs From 4d7bd24851ccfdefff397d4ccb6337c67b478071 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:51:02 +0200 Subject: [PATCH 539/577] Fixed folder creation when saving to folder --- napari_cellseg3d/code_models/crf.py | 2 +- napari_cellseg3d/code_plugins/plugin_convert.py | 8 ++++---- napari_cellseg3d/code_plugins/plugin_crf.py | 2 +- napari_cellseg3d/utils.py | 3 +++ 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 21caf35f..aa9cce75 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -210,7 +210,7 @@ def _run_crf_job(self): raise ValueError("Image and labels must have the same shape.") image = correct_shape_for_crf(image) - # labels = correct_shape_for_crf(labels) + labels = correct_shape_for_crf(labels) yield crf( image, diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index d9026912..77aa9af6 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -76,7 +76,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -175,7 +175,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -342,7 +342,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -436,7 +436,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.binarize_counter.value() if self.layer_choice: diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index 7ac605e9..d8407a0f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -238,7 +238,7 @@ def _start(self): self.result_layer = self.label_layer_loader.layer() self.result_name = self.label_layer_loader.layer_name() - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) image_list = [self.image_layer_loader.layer_data()] labels_list = [self.label_layer_loader.layer_data()] diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 1f0c17ea..8d91abaa 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -131,6 +131,9 @@ def normalize_x(image): return image / 127.5 - 1 +def mkdir_from_str(path: str, exist_ok=True, parents=True): + Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) + def mkdir_from_str(path: str, exist_ok=True, parents=True): Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) From 333a0c3b45db086e56706a70370891e856b8e5b6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 14 May 2023 11:54:07 +0200 Subject: [PATCH 540/577] Fix check_ready for results filewidget --- napari_cellseg3d/interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 2e6c4e78..e9b7a0dc 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -913,7 +913,8 @@ def required(self, is_required): else: with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - + except TypeError: + pass self.check_ready() self._required = is_required From b1d2bac4ac51218fd5e3b1e6daf99103c11bdd7f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 11:28:33 +0200 Subject: [PATCH 541/577] Added remapping in WNet + ruff config --- .pre-commit-config.yaml | 3 ++ napari_cellseg3d/code_models/workers.py | 46 ++++++++++--------------- napari_cellseg3d/utils.py | 30 +++++++++++++--- pyproject.toml | 36 +++++++++++++++++-- 4 files changed, 81 insertions(+), 34 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61ecaae5..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,9 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + - id: check-toml # - repo: https://github.com/pycqa/isort # rev: 5.12.0 # hooks: diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index b1ead29c..3184156d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -520,10 +520,9 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) if self.config.sliding_window_config.is_enabled(): window_size = self.config.sliding_window_config.window_size @@ -540,6 +539,7 @@ def model_output( # outputs = model(inputs) def model_output_wrapper(inputs): + inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) @@ -557,7 +557,7 @@ def model_output_wrapper(inputs): progress=True, ) except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.debug("failed to run sliding window inference") self.raise_error(e, "Error during sliding window inference") logger.debug(f"Inference output shape: {outputs.shape}") @@ -568,11 +568,9 @@ def model_output_wrapper(inputs): if post_process: out = np.array(out).astype(np.float32) out = np.squeeze(out) - return out - else: - return out + return out except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Error during sliding window inference") # sys.stdout = old_stdout # sys.stderr = old_stderr @@ -683,8 +681,7 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - else: - return image + return image def instance_seg( self, semantic_labels, image_id=0, original_filename="layer" @@ -696,10 +693,11 @@ def instance_seg( instance_labels = method.run_method_on_channels(semantic_labels) self.log(f"DEBUG instance results shape : {instance_labels.shape}") - if self.config.filetype == "": - filetype = ".tif" - else: - filetype = "_" + self.config.filetype + filetype = ( + ".tif" + if self.config.filetype == "" + else "_" + self.config.filetype + ) instance_filepath = ( self.config.results_path @@ -781,9 +779,9 @@ def stats_csv(self, instance_labels): stats = [volume_stats(c) for c in instance_labels] else: stats = [volume_stats(instance_labels)] - return stats else: - return None + stats = None + return stats except ValueError as e: self.log(f"Error occurred during stats computing : {e}") return None @@ -807,10 +805,7 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - if self.config.use_crf: - crf_results = self.run_crf(image, out) - else: - crf_results = None + crf_results = self.run_crf(image, out) if self.config.use_crf else None return self.create_inference_result( semantic_labels=out, @@ -992,7 +987,7 @@ def inference(self): model.to("cpu") # self.quit() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.raise_error(e, "Inference failed") self.quit() finally: @@ -1223,10 +1218,7 @@ def train(self): do_sampling = self.config.sampling - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check model = model_class( # FIXME check if correct input_img_size=utils.get_padding_dim(size), use_checkpoint=True @@ -1459,7 +1451,7 @@ def get_loader_func(num_samples): ) except RuntimeError as e: logger.error(f"Error when loading weights : {e}") - logger.error(e, exc_info=True) + logger.exception(e) warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 8d91abaa..88f7077c 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Union import napari import numpy as np @@ -9,6 +10,9 @@ from skimage.filters import gaussian from tifffile import imread, imwrite +if TYPE_CHECKING: + import torch + LOGGER = logging.getLogger(__name__) ############### # Global logging level setting @@ -131,9 +135,6 @@ def normalize_x(image): return image / 127.5 - 1 -def mkdir_from_str(path: str, exist_ok=True, parents=True): - Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) - def mkdir_from_str(path: str, exist_ok=True, parents=True): Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) @@ -212,6 +213,27 @@ def correct_rotation(image): return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) +def normalize_max(image): + """Normalizes an image using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image.reshape(shape) + return image + + +def remap_image( + image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + image = image.reshape(shape) + return image + + def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, @@ -551,8 +573,6 @@ def load_images( "Loading as folder not implemented yet. Use napari to load as folder" ) # images_original = dask_imread(filename_pattern_original) - else: - images_original = imread(filename_pattern_original) # tifffile imread return imread(filename_pattern_original) # tifffile imread diff --git a/pyproject.toml b/pyproject.toml index d223072a..81d2a788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,11 +46,43 @@ where = ["."] [tool.ruff] select = [ "E", "F", "W", - "I", + "A", "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) -ignore = ["E501", "E741"] +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] [tool.black] line-length = 79 From cb26f7680841f735d82d06c510e859fb50528899 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 13:21:06 +0200 Subject: [PATCH 542/577] Run new hooks --- napari_cellseg3d/_tests/test_models.py | 13 +- .../_tests/test_weight_download.py | 2 +- napari_cellseg3d/code_models/crf.py | 25 ++-- .../code_models/instance_segmentation.py | 12 +- .../code_models/model_framework.py | 7 +- .../code_models/models/model_SwinUNetR.py | 2 +- .../code_models/models/model_TRAILMAP_MS.py | 2 +- .../code_models/models/model_WNet.py | 8 +- .../code_models/models/wnet/soft_Ncuts.py | 4 +- napari_cellseg3d/code_models/workers.py | 2 +- .../code_plugins/plugin_convert.py | 127 +++++++++--------- .../code_plugins/plugin_model_inference.py | 6 +- .../code_plugins/plugin_model_training.py | 10 +- .../code_plugins/plugin_review.py | 11 +- napari_cellseg3d/config.py | 8 +- .../dev_scripts/artefact_labeling.py | 16 +-- .../dev_scripts/correct_labels.py | 7 +- napari_cellseg3d/interface.py | 15 +-- pyproject.toml | 2 + 19 files changed, 142 insertions(+), 137 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35af8c76..35174b85 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -1,20 +1,23 @@ import numpy as np import torch +from numpy.random import PCG64, Generator from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss from napari_cellseg3d.config import MODEL_LIST +rand_gen = Generator(PCG64(12345)) + def test_correct_shape_for_crf(): - test = np.random.rand(1, 1, 8, 8, 8) + test = rand_gen.random(size=(1, 1, 8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) - test = np.random.rand(8, 8, 8) + test = rand_gen.random(size=(8, 8, 8)) assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) def test_model_list(): - for model_name in MODEL_LIST.keys(): + for model_name in MODEL_LIST: # if model_name=="test": # continue dims = 128 @@ -46,8 +49,8 @@ def test_soft_ncuts_loss(): def test_crf(qtbot): dims = 8 - mock_image = np.random.rand(1, dims, dims, dims) - mock_label = np.random.rand(2, dims, dims, dims) + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py index 042c9524..ec8df231 100644 --- a/napari_cellseg3d/_tests/test_weight_download.py +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( PRETRAINED_WEIGHTS_DIR, WeightsDownloader, ) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index aa9cce75..8c311059 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,17 +54,15 @@ def correct_shape_for_crf(image, desired_dims=4): - if len(image.shape) == desired_dims: - return image if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( # f"Image shape {image.shape} might have several channels" # ) image = np.squeeze(image, axis=0) - if len(image.shape) < desired_dims: + elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) - return correct_shape_for_crf(image) + return image def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): @@ -185,8 +183,8 @@ class CRFWorker(GeneratorWorker): def __init__( self, - images_list, - labels_list, + images_list: list, + labels_list: list, config: CRFConfig = None, log=None, ): @@ -205,16 +203,19 @@ def _run_crf_job(self): if not CRF_INSTALLED: raise ImportError("pydensecrf is not installed.") - for image, labels in zip(self.images, self.labels): - if image.shape[-3:] != labels.shape[-3:]: + if len(self.images) != len(self.labels): + raise ValueError("Number of images and labels must be the same.") + + for i in range(len(self.images)): + if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - image = correct_shape_for_crf(image) - labels = correct_shape_for_crf(labels) + im = correct_shape_for_crf(self.labels[i]) + prob = correct_shape_for_crf(self.labels[i]) yield crf( - image, - labels, + im, + prob, self.config.sa, self.config.sb, self.config.sg, diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 5e759b0c..e46b64d4 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -94,16 +94,16 @@ def _make_list_from_channels( raise ValueError( f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" ) + if len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) if len(image.shape) == 4: image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - elif len(image.shape) < 2: - raise ValueError( - f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" - ) - else: return [image] + return None def run_method_on_channels(self, image): image_list = self._make_list_from_channels(image) # FIXME rename @@ -590,7 +590,7 @@ def _build(self): self._set_visibility() def _set_visibility(self): - for name in self.instance_widgets.keys(): + for name in self.instance_widgets: if name != self.method_choice.currentText(): for widget in self.instance_widgets[name]: widget.set_visibility(False) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index ddd9cd28..13598af6 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -291,7 +291,12 @@ def _load_weights_path(self): [self._default_weights_folder], file_extension="Weights file (*.pth)", ) - self._update_weights_path(file) + if file[0] == self._default_weights_folder: + return + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 112c59e9..144317f8 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -28,7 +28,7 @@ def __init__( **kwargs, ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( input_img_size, in_channels=1, diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 66d61201..e42d54bf 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -17,7 +17,7 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs): in_channels=in_channels, out_channels=out_channels, **kwargs ) except TypeError as e: - logger.warn(f"Caught TypeError: {e}") + logger.warning(f"Caught TypeError: {e}") super().__init__( in_channels=in_channels, out_channels=out_channels ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index f07ac517..7235bd61 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -28,14 +28,14 @@ def forward(self, x): """Forward ENCODER pass of the W-Net model. Done this way to allow inference on the encoder only when called by sliding_window_inference. """ - enc = self.forward_encoder(x) - # dec = self.forward_decoder(enc) - return enc + return self.forward_encoder(x) + # enc = self.forward_encoder(x) + # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" encoder_checkpoint = state_dict.copy() - for k in state_dict.keys(): + for k in state_dict: if k.startswith("decoder"): encoder_checkpoint.pop(k) # print(encoder_checkpoint.keys()) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 4e84579f..938292c2 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -206,6 +206,7 @@ def forward(self, labels, inputs): return torch.add(torch.neg(loss), K) """ + return None def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -348,5 +349,4 @@ def get_weights(self, inputs): 1, 1, self.W_X.shape[0], self.W_X.shape[1] ) # (1, 1, H*W*D, H*W*D) - W = torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) - return W + return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 3184156d..1aaa05cc 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -54,7 +54,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( ImageStats, volume_stats, ) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 77aa9af6..4357e51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -7,7 +7,7 @@ import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, clear_small_objects, threshold, @@ -98,18 +98,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file)), zoom) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -193,18 +194,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -274,18 +276,19 @@ def _start(self): utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -360,18 +363,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method_on_channels(imread(file)) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method_on_channels(imread(file)) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -454,18 +458,19 @@ def _start(self): utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - utils.save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + utils.save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 0233ac23..50c0fd49 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -10,12 +10,12 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.workers import ( InferenceResult, InferenceWorker, ) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 0670c1c7..339bc631 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -32,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( TrainingReport, TrainingWorker, ) @@ -419,9 +419,7 @@ def check_ready(self): * False and displays a warning if not """ - if self.images_filepaths != [] and self.labels_filepaths != []: - return True - else: + if self.images_filepaths == [] and self.labels_filepaths != []: logger.warning("Image and label paths are not correctly set") return False return True @@ -1060,7 +1058,7 @@ def on_yield(self, report: TrainingReport): self.result_layers[i].data = report.images[i] self.result_layers[i].refresh() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) self.progress.setValue( 100 * (report.epoch + 1) // self.worker_config.max_epochs @@ -1228,7 +1226,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e, exc_info=True) + logger.exception(e) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 235595e4..dd98bcd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -178,11 +178,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - logger.warning( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + logger.warning( + "Image and label dimensions do not match ! Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): @@ -400,7 +399,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e, exc_info=True) + logger.exception(e) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 7250fe78..af42d779 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -6,7 +6,7 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ @@ -89,9 +89,9 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :") + for model_name in MODEL_LIST: + logger.info(f" * {model_name}") return MODEL_LIST.keys() diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index b4712aec..93746eb6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ -import os +import os # TODO(cyril): remove os +from pathlib import Path import napari import numpy as np @@ -6,7 +7,7 @@ from skimage.filters import threshold_otsu from tifffile import imread, imwrite -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -289,18 +290,13 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. """ - if not is_labeled: - # find all the connected components in the artefacts image - labels = ndimage.label(artefacts)[0] - else: - labels = artefacts + labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - artefacts = np.where(np.isin(labels, labels_i), labels, 0) - return artefacts + return np.where(np.isin(labels, labels_i), labels, 0) def create_artefact_labels( @@ -388,7 +384,7 @@ def create_artefact_labels_from_folder( path_labels.sort() path_images.sort() # create the output folder - os.makedirs(path + "/artefact_neurons", exist_ok=True) + Path().mkdir(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 9862c3fa..4a7363b2 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -12,7 +12,7 @@ from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) @@ -228,10 +228,7 @@ def relabel( print("these labels will be added") if test: viewer.close() - if viewer is None: - viewer = napari.view_image(image) - else: - viewer = viewer + viewer = napari.view_image(image) if viewer is None else viewer if not test: viewer.add_labels(artefact_copy, name="labels added") napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index e9b7a0dc..2d84cda9 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -735,8 +735,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = max(aniso_res) - return [res / base for res in aniso_res] + base = min(aniso_res) + return [base / res for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" @@ -810,7 +810,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: logger.warning("Please select a valid layer !") - return + return None return self.layer().data @@ -913,8 +913,7 @@ def required(self, is_required): else: with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - except TypeError: - pass + self.check_ready() self._required = is_required @@ -1051,7 +1050,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min, max, default, step, parent, fixed) + box = class_(min_value, max_value, default, step, parent, fixed) boxes.append(box) return boxes @@ -1225,7 +1224,7 @@ def open_file_dialog( default_path = utils.parse_default_path(possible_paths) return QFileDialog.getOpenFileName( - widget, "Choose file", default_path, file_extension + widget, "Choose file", default_path, filetype ) @@ -1236,7 +1235,7 @@ def open_folder_dialog( default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - filenames = QFileDialog.getExistingDirectory( + return QFileDialog.getExistingDirectory( widget, "Open directory", default_path + "/.." ) diff --git a/pyproject.toml b/pyproject.toml index 81d2a788..7210af6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ exclude = [ "dist", "node_modules", "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", ] [tool.black] From 526e7bac06801f5c579bb550cd084a621affe59e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:06:24 +0200 Subject: [PATCH 543/577] Small docs update --- docs/res/code/plugin_convert.rst | 15 ------- docs/res/code/utils.rst | 4 -- docs/res/guides/custom_model_template.rst | 25 +++++++++++- napari_cellseg3d/code_models/workers.py | 49 +---------------------- 4 files changed, 26 insertions(+), 67 deletions(-) diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 03944510..25006d0f 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,18 +28,3 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ - -Functions ------------------------------------ - -save_folder -***************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder - -save_layer -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer - -show_result -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index e90ee7e0..d9fdcfa2 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,7 +62,3 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images - -format_Warning -************************************** -.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index 35d21137..a70df29b 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -6,7 +6,30 @@ Advanced : Declaring a custom model .. warning:: **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. + Please `file an issue`_ if you would like to add a custom model and we will help you get it working. + +To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: + + class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) # download URL goes in pretrained_models.json + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass + + .. note:: **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute -:: +.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 1aaa05cc..b1114766 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -61,16 +61,6 @@ logger = utils.LOGGER -""" -Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) -so instead, following the instructions in the guides below to have a worker with custom signals, I implemented -a custom worker function.""" - -# FutureReference(): -# https://python-forum.io/thread-31349.html -# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ -# https://napari-staging-site.github.io/guides/stable/threading.html - PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( "models/pretrained" ) @@ -188,9 +178,9 @@ def safe_extract( class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_ + Separate from Worker instances as indicated `on this post`_ - .. _here: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect """ # TODO link ? log_signal = Signal(str) @@ -210,41 +200,6 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files -class ONNXModelWrapper(torch.nn.Module): - """Class to replace torch model by ONNX Runtime session""" - - def __init__(self, file_location): - super().__init__() - try: - import onnxruntime as ort - except ImportError as e: - logger.error("ONNX is not installed but ONNX model was loaded") - logger.error(e) - msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" - logger.error(msg) - raise ImportError(msg) from e - - self.ort_session = ort.InferenceSession( - file_location, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) - - def forward(self, modeL_input): - """Wraps ONNX output in a torch tensor""" - outputs = self.ort_session.run( - None, {"input": modeL_input.cpu().numpy()} - ) - return torch.tensor(outputs[0]) - - def eval(self): - """Dummy function to replace model.eval()""" - pass - - def to(self, device): - """Dummy function to replace model.to(device)""" - pass - - @dataclass class InferenceResult: """Class to record results of a segmentation job""" From 1dc6715417fad29d4c3436a8f8bcc10dec3de78a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:24:43 +0200 Subject: [PATCH 544/577] Testing fix --- napari_cellseg3d/code_models/instance_segmentation.py | 5 ++--- napari_cellseg3d/code_models/models/model_WNet.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index e46b64d4..6037a733 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -102,11 +102,10 @@ def _make_list_from_channels( image = np.squeeze(image) if len(image.shape) == 4: return [im for im in image] - return [image] - return None + return [image] def run_method_on_channels(self, image): - image_list = self._make_list_from_channels(image) # FIXME rename + image_list = self._make_list_from_channels(image) result = np.array([self.run_method(im) for im in image_list]) return result.squeeze() diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 7235bd61..cb5ef6d8 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -21,7 +21,7 @@ def __init__( num_classes=num_classes, ) - # def train(self: T, mode: bool = True) -> T: # FIXME makes inference raise NotImplementedError + # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") def forward(self, x): From 4fe2e6d388fbcdb970b5853007f858c0b919f85d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 14:59:05 +0200 Subject: [PATCH 545/577] Fixed multithread testing (locally) --- .github/workflows/test_and_deploy.yml | 1 + napari_cellseg3d/_tests/test_models.py | 14 ++-- .../_tests/test_plugin_inference.py | 29 ++++---- napari_cellseg3d/_tests/test_training.py | 30 ++++---- .../code_plugins/plugin_model_inference.py | 69 +------------------ .../code_plugins/plugin_model_training.py | 63 +---------------- 6 files changed, 44 insertions(+), 162 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index f230c9ec..e9a66ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -9,6 +9,7 @@ on: - main - npe2 - cy/voronoi-otsu + - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 35174b85..4852f651 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -52,7 +52,7 @@ def test_crf(qtbot): mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) assert len(mock_label.shape) == 4 - crf = CRFWorker([mock_image, mock_image], [mock_label, mock_label]) + crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): assert isinstance(result, np.ndarray) @@ -60,20 +60,20 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=True ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() mock_image = mock_image[0] mock_label = mock_label[0] crf = CRFWorker(mock_image, mock_label) - crf.yielded.connect(on_yield) - crf.start() with qtbot.waitSignal( - signal=crf.finished, timeout=60000, raising=False + signal=crf.finished, timeout=20000, raising=False ) as blocker: blocker.connect(crf.errored) + crf.yielded.connect(on_yield) + crf.start() diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 3dafeabc..d1264218 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,10 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_inference(make_napari_viewer, qtbot): @@ -29,14 +28,16 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.setCurrentIndex(-1) - - # widget.start() # takes too long on Github Actions - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: - # blocker.connect(widget.worker.errored) - - #### assert len(viewer.layers) == 2 + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + widget.worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(widget.config) + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() # takes too long on Github Actions + assert widget.worker is not None + # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index fbfe7bcc..4d558363 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -2,10 +2,9 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_training import Trainer - -# from napari_cellseg3d.config import MODEL_LIST -# from napari_cellseg3d.code_models.models.model_test import TestModel +from napari_cellseg3d.config import MODEL_LIST def test_training(make_napari_viewer, qtbot): @@ -33,18 +32,19 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - # MODEL_LIST["test"] = TestModel() - # widget.model_choice.addItem("test") - # widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - - # widget.start() - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. - # blocker.connect(widget.worker.errored) - # widget.worker.error_signal.connect(on_error) - # widget.worker.train() - # assert widget.worker is not None + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) + + worker_config = widget._set_worker_config() + widget.worker = widget._create_worker_from_config(worker_config) + + with qtbot.waitSignal( + signal=widget.worker.finished, timeout=10000, raising=True + ) as blocker: + blocker.connect(widget.worker.errored) + widget.worker.start() + assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 50c0fd49..30a7786c 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -580,66 +580,7 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.model_info = config.ModelInfo( - name=self.model_choice.currentText(), - model_input_size=self.model_input_size.value(), - ) - - self.weights_config.custom = self.custom_weights_choice.isChecked() - - save_path = self.results_filewidget.text_field.text() - if not self._check_results_path(save_path): - msg = f"ERROR: please set valid results path. Current path is {save_path}" - self.log.print_and_log(msg) - logger.warning(msg) - else: - if self.results_path is None: - self.results_path = save_path - - zoom_config = config.Zoom( - enabled=self.anisotropy_wdgt.enabled(), - zoom_values=self.anisotropy_wdgt.scaling_xyz(), - ) - thresholding_config = config.Thresholding( - enabled=self.thresholding_checkbox.isChecked(), - threshold_value=self.thresholding_slider.slider_value, - ) - - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) - - self.post_process_config = config.PostProcessConfig( - zoom=zoom_config, - thresholding=thresholding_config, - instance=self.instance_config, - ) - - if self.window_infer_box.isChecked(): - size = int(self.window_size_choice.currentText()) - window_config = config.SlidingWindowConfig( - window_size=size, - window_overlap=self.window_overlap_slider.slider_value, - ) - else: - window_config = config.SlidingWindowConfig() - - self.worker_config = config.InferenceWorkerConfig( - device=self.get_device(), - model_info=self.model_info, - weights_config=self.weights_config, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), - compute_stats=self.save_stats_to_csv_box.isChecked(), - post_process_config=self.post_process_config, - sliding_window_config=window_config, - use_crf=self.use_crf.isChecked(), - crf_config=self.crf_widgets.make_config(), - ) + self._set_worker_config() ##################### ##################### ##################### @@ -681,12 +622,8 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config( - self, worker_config: config.InferenceWorkerConfig - ): - if isinstance(worker_config, config.InfererConfig): - raise TypeError("Please provide a valid worker config object") - return InferenceWorker(worker_config=worker_config) + def _create_worker_from_config(self, config: config.InferenceWorkerConfig): + return InferenceWorker(worker_config=config) def _set_worker_config(self) -> config.InferenceWorkerConfig: self.model_info = config.ModelInfo( diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 339bc631..b76e9da2 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -810,63 +810,12 @@ def start(self): self.data = None raise err - model_config = config.ModelInfo( - name=self.model_choice.currentText() - ) - - self.weights_config.path = self.weights_config.path - self.weights_config.custom = self.custom_weights_choice.isChecked() - self.weights_config.use_pretrained = ( - not self.use_transfer_choice.isChecked() - ) - - deterministic_config = config.DeterministicConfig( - enabled=self.use_deterministic_choice.isChecked(), - seed=self.box_seed.value(), - ) - - validation_percent = ( - self.validation_percent_choice.slider_value / 100 - ) - - results_path_folder = Path( - self.results_path - + f"/{model_config.name}_{utils.get_date_time()}" - ) - Path(results_path_folder).mkdir( - parents=True, exist_ok=False - ) # avoid overwrite where possible - - patch_size = [w.value() for w in self.patch_size_widgets] - - logger.debug("Loading config...") - self.worker_config = config.TrainingWorkerConfig( - device=self.get_device(), - model_info=model_config, - weights_info=self.weights_config, - train_data_dict=self.data, - validation_percent=validation_percent, - max_epochs=self.epoch_choice.value(), - loss_function=self.get_loss(self.loss_choice.currentText()), - learning_rate=float(self.learning_rate_choice.currentText()), - scheduler_patience=self.scheduler_patience_choice.value(), - scheduler_factor=self.scheduler_factor_choice.slider_value, - validation_interval=self.val_interval_choice.value(), - batch_size=self.batch_choice.slider_value, - results_path_folder=str(results_path_folder), - sampling=self.patch_choice.isChecked(), - num_samples=self.sample_choice_slider.slider_value, - sample_size=patch_size, - do_augmentation=self.augment_choice.isChecked(), - deterministic_config=deterministic_config, - ) # TODO(cyril) continue to put params in config - self.config = config.TrainerConfig( save_as_zip=self.zip_choice.isChecked() ) self._set_worker_config() - self.worker = TrainingWorker(worker_config=self.worker_config) + self.worker = TrainingWorker(config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] @@ -894,14 +843,8 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config( - self, worker_config: config.TrainingWorkerConfig - ): - if isinstance(config, config.TrainerConfig): - raise TypeError( - "Expected a TrainingWorkerConfig, got a TrainerConfig" - ) - return TrainingWorker(worker_config=worker_config) + def _create_worker_from_config(self, config: config.TrainingWorkerConfig): + return TrainingWorker(config=config) def _set_worker_config(self) -> config.TrainingWorkerConfig: model_config = config.ModelInfo(name=self.model_choice.currentText()) From 259823ad2504f3018bf7b6699fd4764a9e9a236c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:06:02 +0200 Subject: [PATCH 546/577] Added proper tests for train/infer --- .../_tests/test_plugin_inference.py | 36 ++++++++++++++----- napari_cellseg3d/_tests/test_training.py | 34 ++++++++++++------ .../code_plugins/plugin_model_inference.py | 8 +++-- .../code_plugins/plugin_model_training.py | 10 ++++-- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index d1264218..04305082 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -4,7 +4,10 @@ from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_inference import ( + InferenceResult, + Inferer, +) from napari_cellseg3d.config import MODEL_LIST @@ -28,16 +31,31 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(widget.config) - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() # takes too long on Github Actions - assert widget.worker is not None + assert widget.worker_config is not None + assert widget.model_info is not None + worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None + assert worker.config.model_info is not None + worker.config.layer = viewer.layers[0].data + assert worker.config.layer is not None + worker.log_parameters() + + res = next(worker.inference()) + assert isinstance(res, InferenceResult) + assert res.result.shape == (6, 6, 6) + + # def on_error(e): + # print(e) + # assert False + # with qtbot.waitSignal( + # signal=worker.finished, timeout=10000, raising=True + # ) as blocker: + # worker.error_signal.connect(on_error) + # blocker.connect(worker.errored) + # worker.inference() # takes too long on Github Actions # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 4d558363..080df419 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,7 +3,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_model_training import ( + Trainer, + TrainingReport, +) from napari_cellseg3d.config import MODEL_LIST @@ -32,19 +35,30 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) worker_config = widget._set_worker_config() - widget.worker = widget._create_worker_from_config(worker_config) - - with qtbot.waitSignal( - signal=widget.worker.finished, timeout=10000, raising=True - ) as blocker: - blocker.connect(widget.worker.errored) - widget.worker.start() - assert widget.worker is not None + worker = widget._create_worker_from_config(worker_config) + worker.config.train_data_dict = [{"image": im_path, "label": im_path}] + worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.log_parameters() + res = next(worker.train()) + + assert isinstance(res, TrainingReport) + + # def on_error(e): + # print(e) + # assert False + # + # with qtbot.waitSignal( + # signal=widget.worker.finished, timeout=10000, raising=True + # ) as blocker: + # blocker.connect(widget.worker.errored) + # widget.worker.error_signal.connect(on_error) + # widget.worker.train() + # assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 30a7786c..00aff5f5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -622,8 +622,12 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.InferenceWorkerConfig): - return InferenceWorker(worker_config=config) + def _create_worker_from_config( + self, worker_config: config.InferenceWorkerConfig + ): + if isinstance(worker_config, config.InfererConfig): + raise TypeError("Please provide a valid worker config object") + return InferenceWorker(worker_config=worker_config) def _set_worker_config(self) -> config.InferenceWorkerConfig: self.model_info = config.ModelInfo( diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index b76e9da2..f00923e3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -843,8 +843,14 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") - def _create_worker_from_config(self, config: config.TrainingWorkerConfig): - return TrainingWorker(config=config) + def _create_worker_from_config( + self, worker_config: config.TrainingWorkerConfig + ): + if isinstance(config, config.TrainerConfig): + raise TypeError( + "Expected a TrainingWorkerConfig, got a TrainerConfig" + ) + return TrainingWorker(worker_config=worker_config) def _set_worker_config(self) -> config.TrainingWorkerConfig: model_config = config.ModelInfo(name=self.model_choice.currentText()) From 8d30d5ca395b7052a46a91292b53ca33a2d3d07b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:31:36 +0200 Subject: [PATCH 547/577] Slight coverage increase --- napari_cellseg3d/_tests/test_plugin_inference.py | 13 ++----------- napari_cellseg3d/_tests/test_training.py | 1 + 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 04305082..c437ac83 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -39,23 +39,14 @@ def test_inference(make_napari_viewer, qtbot): assert widget.worker_config is not None assert widget.model_info is not None worker = widget._create_worker_from_config(widget.worker_config) + assert worker.config is not None assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data + worker.config.post_process_config.instance.enabled = True assert worker.config.layer is not None worker.log_parameters() res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) - - # def on_error(e): - # print(e) - # assert False - # with qtbot.waitSignal( - # signal=worker.finished, timeout=10000, raising=True - # ) as blocker: - # worker.error_signal.connect(on_error) - # blocker.connect(worker.errored) - # worker.inference() # takes too long on Github Actions - # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 080df419..e7f1e07b 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -43,6 +43,7 @@ def test_training(make_napari_viewer, qtbot): worker = widget._create_worker_from_config(worker_config) worker.config.train_data_dict = [{"image": im_path, "label": im_path}] worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.config.max_epochs = 1 worker.log_parameters() res = next(worker.train()) From e83847df6441bfa71a5f77f5d61948808a6b9705 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 16 May 2023 17:45:47 +0200 Subject: [PATCH 548/577] Update test_plugin_inference.py --- napari_cellseg3d/_tests/test_plugin_inference.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index c437ac83..ca8e84d4 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,6 +3,9 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.instance_segmentation import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.models.model_test import TestModel from napari_cellseg3d.code_plugins.plugin_model_inference import ( InferenceResult, @@ -44,6 +47,10 @@ def test_inference(make_napari_viewer, qtbot): assert worker.config.model_info is not None worker.config.layer = viewer.layers[0].data worker.config.post_process_config.instance.enabled = True + worker.config.post_process_config.instance.method = ( + INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() + ) + assert worker.config.layer is not None worker.log_parameters() From 25fe3d7cd10a13df7951d9d05228ee02cf349af3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 17 May 2023 11:41:39 +0200 Subject: [PATCH 549/577] Set window inference to 64 for WNet --- napari_cellseg3d/code_plugins/plugin_model_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 00aff5f5..bdd2f123 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -171,7 +171,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(3) # set to 64 by default + self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -196,7 +196,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) - self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## From 644038e547ca6a99a54c9509ff9f9023bcfb9c0c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 20 May 2023 09:22:52 +0200 Subject: [PATCH 550/577] Moved normalization to the correct place --- napari_cellseg3d/code_models/workers.py | 2 +- napari_cellseg3d/utils.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index b1114766..c5e1552e 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -492,9 +492,9 @@ def model_output( logger.debug(f"inputs type : {inputs.dtype}") try: # outputs = model(inputs) + inputs = utils.remap_image(inputs) def model_output_wrapper(inputs): - inputs = utils.remap_image(inputs) result = model(inputs) return post_process_transforms(result) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 88f7077c..663872c4 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -223,12 +223,18 @@ def normalize_max(image): def remap_image( - image: Union["np.ndarray", "torch.Tensor"], new_max=100, new_min=0 + image: Union["np.ndarray", "torch.Tensor"], + new_max=100, + new_min=0, + prev_max=None, + prev_min=None, ): """Normalizes a numpy array or Tensor using the max and min value""" shape = image.shape image = image.flatten() - image = (image - image.min()) / (image.max() - image.min()) + im_max = prev_max if prev_max is not None else image.max() + im_min = prev_min if prev_min is not None else image.min() + image = (image - im_min) / (im_max - im_min) image = image * (new_max - new_min) + new_min image = image.reshape(shape) return image From 14b3516f481bd019b9784b23838ca834fd6c8f6a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 11:09:48 +0200 Subject: [PATCH 551/577] Added auto-set dims for cropping --- napari_cellseg3d/code_plugins/plugin_crop.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 0c2e4042..ab552904 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -3,6 +3,7 @@ import napari import numpy as np from magicgui import magicgui +from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -43,9 +44,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") - self.image_layer_loader.layer_list.currentIndexChanged.connect( - self.auto_set_dims - ) + self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -141,12 +140,10 @@ def auto_set_dims(self): logger.debug(self.image_layer_loader.layer_name()) data = self.image_layer_loader.layer_data() if data is not None: - logger.debug(f"auto_set_dims : {data.shape}") + logger.debug("auto_set_dims : {}".format(data.shape)) if len(data.shape) == 3: for i, box in enumerate(self.crop_size_widgets): - logger.debug( - f"setting dim {i} to {floor(data.shape[i]/2)}" - ) + logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") box.setValue(floor(data.shape[i] / 2)) def _build(self): From e3ea954f33cfd5a09e805db12bafbc18390b3134 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 24 May 2023 12:19:37 +0200 Subject: [PATCH 552/577] Update test_plugin_utils.py --- napari_cellseg3d/_tests/test_plugin_utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0f183fa4..60c25ccc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,7 +1,5 @@ -from pathlib import Path - import numpy as np -from tifffile import imread +from numpy.random import PCG64, Generator from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, @@ -15,10 +13,9 @@ def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - im_path = str(Path(__file__).resolve().parent / "res/test.tif") - image = imread(im_path) - view.add_image(image) - view.add_labels(image.astype(np.uint8)) + image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image_layer = view.add_image(image, name="image") + label_layer = view.add_labels(image.astype(np.uint8), name="labels") view.window.add_dock_widget(widget) view.dims.ndisplay = 3 @@ -32,4 +29,6 @@ def test_utils_plugin(make_napari_viewer): menu = widget.utils_widgets[i].instance_widgets.method_choice menu.setCurrentIndex(menu.currentIndex() + 1) + assert len(image_layer.data.shape) == 3 + assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() From 6d47bb2687ddaa799ca98e654c146bbf97994f58 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 15:50:18 +0200 Subject: [PATCH 553/577] More WNet - Added experimental .pt loading for jit models - More CRF tests - Optimized WNet by loading inference only --- napari_cellseg3d/_tests/test_models.py | 61 ++++++++++++------ napari_cellseg3d/code_models/crf.py | 8 ++- .../code_models/model_framework.py | 2 +- .../code_models/models/model_WNet.py | 18 +++--- .../code_models/models/wnet/model.py | 19 ++++-- napari_cellseg3d/code_models/workers.py | 62 ++++++++++--------- napari_cellseg3d/code_plugins/plugin_crop.py | 12 ++-- .../dev_scripts/correct_labels.py | 12 ++-- pyproject.toml | 1 + 9 files changed, 120 insertions(+), 75 deletions(-) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 4852f651..c67b3cab 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -2,9 +2,14 @@ import torch from numpy.random import PCG64, Generator -from napari_cellseg3d.code_models.crf import CRFWorker, correct_shape_for_crf +from napari_cellseg3d.code_models.crf import ( + CRFWorker, + correct_shape_for_crf, + crf_batch, + crf_with_config, +) from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss -from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d.config import MODEL_LIST, CRFConfig rand_gen = Generator(PCG64(12345)) @@ -47,7 +52,38 @@ def test_soft_ncuts_loss(): assert 0 <= res <= 1 -def test_crf(qtbot): +def test_crf_batch(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_batch( + np.array([mock_image, mock_image, mock_image]), + np.array([mock_label, mock_label, mock_label]), + sa=config.sa, + sb=config.sb, + sg=config.sg, + w1=config.w1, + w2=config.w2, + ) + + assert isinstance(result, np.ndarray) + assert result.shape == (3, 2, dims, dims, dims) + + +def test_crf_config(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_with_config(mock_image, mock_label, config) + assert isinstance(result, np.ndarray) + assert result.shape == mock_label.shape + + +def test_crf_worker(qtbot): dims = 8 mock_image = rand_gen.random(size=(1, dims, dims, dims)) mock_label = rand_gen.random(size=(2, dims, dims, dims)) @@ -60,20 +96,5 @@ def on_yield(result): assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=True - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() - - mock_image = mock_image[0] - mock_label = mock_label[0] - - crf = CRFWorker(mock_image, mock_label) - with qtbot.waitSignal( - signal=crf.finished, timeout=20000, raising=False - ) as blocker: - blocker.connect(crf.errored) - crf.yielded.connect(on_yield) - crf.start() + result = next(crf._run_crf_job()) + on_yield(result) diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 8c311059..b362246a 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -54,6 +54,8 @@ def correct_shape_for_crf(image, desired_dims=4): + logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") + logger.debug(f"Image shape: {image.shape}") if len(image.shape) > desired_dims: # if image.shape[0] > 1: # raise ValueError( @@ -62,6 +64,7 @@ def correct_shape_for_crf(image, desired_dims=4): image = np.squeeze(image, axis=0) elif len(image.shape) < desired_dims: image = np.expand_dims(image, axis=0) + logger.debug(f"Corrected image shape: {image.shape}") return image @@ -210,9 +213,12 @@ def _run_crf_job(self): if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: raise ValueError("Image and labels must have the same shape.") - im = correct_shape_for_crf(self.labels[i]) + im = correct_shape_for_crf(self.images[i]) prob = correct_shape_for_crf(self.labels[i]) + logger.debug(f"image shape : {im.shape}") + logger.debug(f"labels shape : {prob.shape}") + yield crf( im, prob, diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 13598af6..b244f84d 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -289,7 +289,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - file_extension="Weights file (*.pth)", + filetype="Weights file (*.pth, *.pt)", ) if file[0] == self._default_weights_folder: return diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index cb5ef6d8..62142e73 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -1,8 +1,8 @@ # local -from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder -class WNet_(WNet): +class WNet_(WNet_encoder): use_default_training = False weights_file = "wnet.pth" @@ -24,13 +24,13 @@ def __init__( # def train(self: T, mode: bool = True) -> T: # raise NotImplementedError("Training not implemented for WNet") - def forward(self, x): - """Forward ENCODER pass of the W-Net model. - Done this way to allow inference on the encoder only when called by sliding_window_inference. - """ - return self.forward_encoder(x) - # enc = self.forward_encoder(x) - # return self.forward_decoder(enc) + # def forward(self, x): + # """Forward ENCODER pass of the W-Net model. + # Done this way to allow inference on the encoder only when called by sliding_window_inference. + # """ + # return self.forward_encoder(x) + # # enc = self.forward_encoder(x) + # # return self.forward_decoder(enc) def load_state_dict(self, state_dict, strict=False): """Load the model state dict for inference, without the decoder weights.""" diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 585ea0dd..a23084d0 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -16,6 +16,19 @@ ] +class WNet_encoder(nn.Module): + """WNet with encoder only.""" + + def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + super().__init__() + self.device = device + self.encoder = UNet(device, in_channels, num_classes, encoder=True) + + def forward(self, x): + """Forward pass of the W-Net model.""" + return self.forward_encoder(x) + + class WNet(nn.Module): """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. The model performs unsupervised segmentation of 3D images. @@ -36,13 +49,11 @@ def forward(self, x): def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" - enc = self.encoder(x) - return enc + return self.encoder(x) def forward_decoder(self, enc): """Forward pass of the decoder part of the W-Net model.""" - dec = self.decoder(enc) - return dec + return self.decoder(enc) class UNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c5e1552e..c4521599 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -820,41 +820,43 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config - - # try: - self.log("Instantiating model...") - model = model_class( # FIXME test if works - input_img_size=[dims, dims, dims], - device=self.config.device, - num_classes=self.config.model_info.num_classes, - ) - # try: - model = model.to(self.config.device) - # except Exception as e: - # self.raise_error(e, "Issue loading model to device") - # logger.debug(f"model : {model}") - if model is None: - raise ValueError("Model is None") + if Path(weights_config.path).suffix == ".pt": + model = torch.jit.load(weights_config.path) # try: - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path else: - self.downloader.download_weights( - model_name, - model_class.weights_file, - ) - weights = str( - PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[dims, dims, dims], + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, + ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + ) - model.load_state_dict( # note that this is redefined in WNet_ - torch.load( - weights, - map_location=self.config.device, + model.load_state_dict( # note that this is redefined in WNet_ + torch.load( + weights, + map_location=self.config.device, + ) ) - ) - self.log("Done") + self.log("Done") # except Exception as e: # self.raise_error(e, "Issue loading weights") # except Exception as e: diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index ab552904..74691e1f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,9 +1,9 @@ +from math import floor from pathlib import Path import napari import numpy as np from magicgui import magicgui -from math import floor # Qt from qtpy.QtWidgets import QSizePolicy @@ -44,7 +44,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") - self.image_layer_loader.layer_list.currentIndexChanged.connect(self.auto_set_dims) + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self.auto_set_dims + ) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) @@ -140,10 +142,12 @@ def auto_set_dims(self): logger.debug(self.image_layer_loader.layer_name()) data = self.image_layer_loader.layer_data() if data is not None: - logger.debug("auto_set_dims : {}".format(data.shape)) + logger.debug(f"auto_set_dims : {data.shape}") if len(data.shape) == 3: for i, box in enumerate(self.crop_size_widgets): - logger.debug(f"setting dim {i} to {floor(data.shape[i]/2)}") + logger.debug( + f"setting dim {i} to {floor(data.shape[i]/2)}" + ) box.setValue(floor(data.shape[i] / 2)) def _build(self): diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 4a7363b2..f413812d 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -363,9 +363,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif") -# -# image_path = str(im_path / "volumes/images.tif") -# gt_labels_path = str(im_path / "labels/testing_im.tif") -# relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") + + image_path = str(im_path / "volumes/c1images.tif") + gt_labels_path = str(im_path / "labels/c1labels.tif") + relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) diff --git a/pyproject.toml b/pyproject.toml index 7210af6e..87cc2e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ docs = [ test = [ "pytest", "pytest_qt", + "pytest-cov", "coverage", "tox", "twine", From 8fd582da790b9152b1b5045bc514e57276bab339 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:12:07 +0200 Subject: [PATCH 554/577] Update crf test/deps for testing --- .github/workflows/test_and_deploy.yml | 2 +- napari_cellseg3d/_tests/test_models.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index e9a66ae2..0911e358 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions -# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index c67b3cab..ec7462db 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -68,7 +68,6 @@ def test_crf_batch(): w2=config.w2, ) - assert isinstance(result, np.ndarray) assert result.shape == (3, 2, dims, dims, dims) @@ -79,7 +78,6 @@ def test_crf_config(): config = CRFConfig() result = crf_with_config(mock_image, mock_label, config) - assert isinstance(result, np.ndarray) assert result.shape == mock_label.shape @@ -91,7 +89,6 @@ def test_crf_worker(qtbot): crf = CRFWorker([mock_image], [mock_label]) def on_yield(result): - assert isinstance(result, np.ndarray) assert len(result.shape) == 4 assert len(mock_label.shape) == 4 assert result.shape[-3:] == mock_label.shape[-3:] From ced0422ff573e14f8e21e0c7fa0502b6cff51f4f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:20:30 +0200 Subject: [PATCH 555/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0911e358..d09be5f0 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - python -m pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -87,6 +86,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From ba51551184f98ca83e0ae28fce7deb097b4be6f8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:34:33 +0200 Subject: [PATCH 556/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d09be5f0..d36e03a3 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,6 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions + pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox @@ -86,7 +87,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -U setuptools setuptools_scm wheel twine build - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf - name: Build and publish env: TWINE_USERNAME: __token__ From 2deb7a8e3ce3c1ea65b5f5cefe7731d3d0ada8ec Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:28 +0200 Subject: [PATCH 557/577] Update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index bb02bb56..dbc84bb5 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf : git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python extras = crf From 8820c2b47a4a7ccc7597e8b734531cea2fe4be4c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:42:45 +0200 Subject: [PATCH 558/577] Update test_and_deploy.yml --- .github/workflows/test_and_deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index d36e03a3..60bc5505 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox From 3a03e27d69dd375c573165f1571dd8dd315e405e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 May 2023 16:50:44 +0200 Subject: [PATCH 559/577] Trying to fix tox install of pydensecrf --- .github/workflows/test_and_deploy.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 60bc5505..e9a66ae2 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions -# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/tox.ini b/tox.ini index dbc84bb5..0605fc8c 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = magicgui pytest-qt qtpy - pydensecrf: git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf + git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf ; pyopencl[pocl] ; opencv-python extras = crf From dcd1f7e0bcf7ab5c951163a09e6f6ac42f72aa53 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:23:51 +0200 Subject: [PATCH 560/577] Added experimental ONNX support for inference --- .../code_models/model_framework.py | 7 +--- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 34 ++++++++++++++++++- .../code_plugins/plugin_model_inference.py | 4 +-- pyproject.toml | 8 +++++ 5 files changed, 45 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index b244f84d..636746a2 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -291,12 +291,7 @@ def _load_weights_path(self): [self._default_weights_folder], filetype="Weights file (*.pth, *.pt)", ) - if file[0] == self._default_weights_folder: - return - if file is not None and file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + self._update_weights_path(file) @staticmethod def get_device(show=True): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index a23084d0..f98829bb 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -26,7 +26,7 @@ def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): def forward(self, x): """Forward pass of the W-Net model.""" - return self.forward_encoder(x) + return self.encoder(x) class WNet(nn.Module): diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c4521599..cae5ed8e 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,6 +199,34 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files +class ONNXModelWrapper(torch.nn.Module): + """Class to replace torch model if ONNX is used""" + def __init__(self, file_location): + super().__init__() + try: + import onnx + import onnxruntime as ort + except ImportError as e: + logger.error("ONNX is not installed but ONNX model was loaded") + logger.error(e) + msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" + logger.error(msg) + raise ImportError(msg) + + self.ort_session = ort.InferenceSession( + file_location, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + + def forward(self, modeL_input): + outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + return torch.tensor(outputs[0]) + + def eval(self): + return True + + def to(self, device): + return True @dataclass class InferenceResult: @@ -821,9 +849,13 @@ def inference(self): weights_config = self.config.weights_config post_process_config = self.config.post_process_config if Path(weights_config.path).suffix == ".pt": + self.log("Instantiating PyTorch jit model...") model = torch.jit.load(weights_config.path) # try: - else: + elif Path(weights_config.path).suffix == ".onnx": + self.log("Instantiating ONNX model...") + model = ONNXModelWrapper(weights_config.path) + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index bdd2f123..599ec5b3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING - +from pathlib import Path import numpy as np import pandas as pd @@ -356,7 +356,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - file_extension="Weights file (*.pth *.pt *.onnx)", + filetype="Weights file (*.pth, *.pt, *.onnx)", ) self._update_weights_path(file) diff --git a/pyproject.toml b/pyproject.toml index 87cc2e1d..2783761e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,3 +118,11 @@ test = [ "twine", "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] +onnx-cpu = [ + "onnx", + "onnxruntime" +] +onnx-gpu = [ + "onnx", + "onnxruntime-gpu" +] From 3ba51f7933699844b99433dda092cc72e01b4956 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:47:48 +0200 Subject: [PATCH 561/577] Updated WNet for ONNX conversion --- .../code_models/models/wnet/model.py | 59 +++++++++++-------- napari_cellseg3d/code_models/workers.py | 9 ++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index f98829bb..23584b30 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -59,18 +59,33 @@ def forward_decoder(self, enc): class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, encoder=True): + def __init__( + self, device, in_channels, out_channels, encoder=True, dropout=0.65 + ): super(UNet, self).__init__() self.device = device - self.in_b = InBlock(device, in_channels, 64) - self.conv1 = Block(device, 64, 128) - self.conv2 = Block(device, 128, 256) - self.conv3 = Block(device, 256, 512) - self.bot = Block(device, 512, 1024) - self.deconv1 = Block(device, 1024, 512) - self.deconv2 = Block(device, 512, 256) - self.deconv3 = Block(device, 256, 128) - self.out_b = OutBlock(device, 128, out_channels) + self.max_pool = nn.MaxPool3d(2) + self.in_b = InBlock(device, in_channels, 64, dropout=dropout) + self.conv1 = Block(device, 64, 128, dropout=dropout) + self.conv2 = Block(device, 128, 256, dropout=dropout) + self.conv3 = Block(device, 256, 512, dropout=dropout) + self.bot = Block(device, 512, 1024, dropout=dropout) + self.deconv1 = Block(device, 1024, 512, dropout=dropout) + self.conv_trans1 = nn.ConvTranspose3d( + 1024, 512, 2, stride=2, device=self.device + ) + self.deconv2 = Block(device, 512, 256, dropout=dropout) + self.conv_trans2 = nn.ConvTranspose3d( + 512, 256, 2, stride=2, device=self.device + ) + self.deconv3 = Block(device, 256, 128, dropout=dropout) + self.conv_trans3 = nn.ConvTranspose3d( + 256, 128, 2, stride=2, device=self.device + ) + self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) + self.conv_trans_out = nn.ConvTranspose3d( + 128, 64, 2, stride=2, device=self.device + ) self.sm = nn.Softmax(dim=1).to(device) self.encoder = encoder @@ -78,17 +93,15 @@ def __init__(self, device, in_channels, out_channels, encoder=True): def forward(self, x): """Forward pass of the U-Net model.""" in_b = self.in_b(x.to(self.device)) - c1 = self.conv1(nn.MaxPool3d(2)(in_b)) - c2 = self.conv2(nn.MaxPool3d(2)(c1)) - c3 = self.conv3(nn.MaxPool3d(2)(c2)) - x = self.bot(nn.MaxPool3d(2)(c3)) + c1 = self.conv1(self.max_pool(in_b)) + c2 = self.conv2(self.max_pool(c1)) + c3 = self.conv3(self.max_pool(c2)) + x = self.bot(self.max_pool(c3)) x = self.deconv1( torch.cat( [ c3, - nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device - )(x), + self.conv_trans1(x), ], dim=1, ) @@ -97,9 +110,7 @@ def forward(self, x): torch.cat( [ c2, - nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device - )(x), + self.conv_trans2(x), ], dim=1, ) @@ -108,9 +119,7 @@ def forward(self, x): torch.cat( [ c1, - nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device - )(x), + self.conv_trans3(x), ], dim=1, ) @@ -119,9 +128,7 @@ def forward(self, x): torch.cat( [ in_b, - nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device - )(x), + self.conv_trans_out(x), ], dim=1, ) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index cae5ed8e..bb5a113d 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -200,7 +200,7 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files class ONNXModelWrapper(torch.nn.Module): - """Class to replace torch model if ONNX is used""" + """Class to replace torch model by ONNX Runtime session""" def __init__(self, file_location): super().__init__() try: @@ -219,14 +219,17 @@ def __init__(self, file_location): ) def forward(self, modeL_input): + """Wraps ONNX output in a torch tensor""" outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) return torch.tensor(outputs[0]) def eval(self): - return True + """Dummy function to replace model.eval()""" + pass def to(self, device): - return True + """Dummy function to replace model.to(device)""" + pass @dataclass class InferenceResult: From 99e7e2aacc4f7d8135740446cfabb9a23569ebb2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 May 2023 10:56:45 +0200 Subject: [PATCH 562/577] Added dropout param --- .../code_models/models/wnet/model.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 23584b30..3416acb1 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -141,17 +141,17 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(InBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -163,19 +163,19 @@ def forward(self, x): class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(Block, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), nn.Conv3d(in_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), nn.Conv3d(out_channels, out_channels, 1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(out_channels, device=device), ).to(device) @@ -187,21 +187,21 @@ def forward(self, x): class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels): + def __init__(self, device, in_channels, out_channels, dropout=0.65): super(OutBlock, self).__init__() self.device = device self.module = nn.Sequential( nn.Conv3d(in_channels, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, 64, 3, padding=1, device=device), nn.ReLU(), - nn.Dropout(p=0.65), + nn.Dropout(p=dropout), nn.BatchNorm3d(64, device=device), nn.Conv3d(64, out_channels, 1, device=device), ).to(device) def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) + return self.module(x.to(self.device)) \ No newline at end of file From 8678bfbf9e9002608b8e67056efd717083257ae1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 31 May 2023 16:13:42 +0200 Subject: [PATCH 563/577] Minor fixes in training --- napari_cellseg3d/code_models/workers.py | 7 +++---- napari_cellseg3d/code_plugins/plugin_model_training.py | 2 +- napari_cellseg3d/interface.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index bb5a113d..c67ea523 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1538,7 +1538,6 @@ def get_loader_func(num_samples): val_data["image"].to(device), val_data["label"].to(device), ) - self.log("Performing validation...") try: with torch.no_grad(): val_outputs = sliding_window_inference( @@ -1607,8 +1606,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1621,7 +1620,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index f00923e3..3e666dcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -815,7 +815,7 @@ def start(self): ) self._set_worker_config() - self.worker = TrainingWorker(config=self.worker_config) + self.worker = TrainingWorker(worker_config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 2d84cda9..de4d3206 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1236,7 +1236,7 @@ def open_folder_dialog( logger.info(f"Default : {default_path}") return QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + "/.." + widget, "Open directory", default_path # + "/.." ) From 858fe7e11061adf7153d583fe289e8bf71d48fbb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:31:23 +0200 Subject: [PATCH 564/577] Fix weights file extension in inference + coverage - Remove unused scripts - More tests - Fixed weights type in inference --- .gitignore | 1 + .../_tests/test_labels_correction.py | 8 +- .../_tests/test_plugin_inference.py | 2 + napari_cellseg3d/_tests/test_utils.py | 18 ++- .../code_models/model_framework.py | 2 +- .../code_models/models/wnet/crf.py | 112 -------------- napari_cellseg3d/code_plugins/plugin_crf.py | 6 +- .../code_plugins/plugin_model_inference.py | 8 +- .../dev_scripts/evaluate_labels.py | 2 +- .../extract_extra_channels_labels.py | 144 ------------------ napari_cellseg3d/interface.py | 4 +- 11 files changed, 34 insertions(+), 273 deletions(-) delete mode 100644 napari_cellseg3d/code_models/models/wnet/crf.py delete mode 100644 napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py diff --git a/.gitignore b/.gitignore index df67a187..7460d861 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ notebooks/instance_test.ipynb !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif +cov.syspath.txt diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..b4f13238 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -37,16 +37,16 @@ def test_correct_labels(): ) -def test_relabel(make_napari_viewer): - viewer = make_napari_viewer() +def test_relabel(): cl.relabel( str(image_path), str(labels_path), go_fast=True, - viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance(labels, labels, print_details=True) + el.evaluate_model_performance( + labels, labels, print_details=True, visualize=False + ) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index ca8e84d4..1ae83102 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -57,3 +57,5 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (6, 6, 6) + + widget.on_yield(res) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 48550747..dc680b35 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,5 +1,5 @@ -import os from functools import partial +from pathlib import Path import numpy as np import torch @@ -9,7 +9,7 @@ def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] + test_list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -35,7 +35,7 @@ def test_fill_list_in_between(): fill = partial(utils.fill_list_in_between, n=2, fill_value="") - assert fill(list) == res + assert fill(test_list) == res def test_align_array_sizes(): @@ -110,8 +110,16 @@ def test_normalize_x(): def test_parse_default_path(): - user_path = os.path.expanduser("~") - assert utils.parse_default_path([None]) == user_path + user_path = Path().home() + assert utils.parse_default_path([None]) == str(user_path) + + test_path = "C:/test/test" + path = [test_path, None, None] + assert utils.parse_default_path(path) == test_path + + long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime" + path = [test_path, None, None, long_path, ""] + assert utils.parse_default_path(path) == long_path def test_thread_test(make_napari_viewer): diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 636746a2..ddd9cd28 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -289,7 +289,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt)", + file_extension="Weights file (*.pth)", ) self._update_weights_path(file) diff --git a/napari_cellseg3d/code_models/models/wnet/crf.py b/napari_cellseg3d/code_models/models/wnet/crf.py deleted file mode 100644 index 004db3a1..00000000 --- a/napari_cellseg3d/code_models/models/wnet/crf.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Implements the CRF post-processing step for the W-Net. -Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - -Also uses research from: -Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials -Philipp Krähenbühl and Vladlen Koltun -NIPS 2011 - -Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf. -""" - -import numpy as np -import pydensecrf.densecrf as dcrf -from pydensecrf.utils import ( - create_pairwise_bilateral, - create_pairwise_gaussian, - unary_from_softmax, -) - -__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" -__credits__ = [ - "Yves Paychère", - "Colin Hofmann", - "Cyril Achard", - "Philipp Krähenbühl", - "Vladlen Koltun", - "Liang-Chieh Chen", - "George Papandreou", - "Iasonas Kokkinos", - "Kevin Murphy", - "Alan L. Yuille", - "Xide Xia", - "Brian Kulis", - "Lucas Beyer", -] - - -def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): - """CRF post-processing step for the W-Net, applied to a batch of images. - - Args: - images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. - probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. - """ - - return np.stack( - [ - crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) - for i in range(images.shape[0]) - ], - axis=0, - ) - - -def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): - """Implements the CRF post-processing step for the W-Net. - Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. - Implemented using the pydensecrf library. - - Args: - image (np.ndarray): Array of shape (C, H, W, D) containing the input image. - prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. - sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. - sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. - sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. - - Returns: - np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. - """ - d = dcrf.DenseCRF( - image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] - ) - # print(f"Image shape : {image.shape}") - # print(f"Prob shape : {prob.shape}") - # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels - - # Get unary potentials from softmax probabilities - U = unary_from_softmax(prob) - d.setUnaryEnergy(U) - - # Generate pairwise potentials - featsGaussian = create_pairwise_gaussian( - sdims=(sg, sg, sg), shape=image.shape[1:] - ) # image.shape) - featsBilateral = create_pairwise_bilateral( - sdims=(sa, sa, sa), - schan=tuple([sb for i in range(image.shape[0])]), - img=image, - chdim=-1, - ) - - # Add pairwise potentials to the CRF - compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( - [1 for i in range(prob.shape[0])] - # , dtype=np.float32 - ) - d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) - d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) - - # Run inference - Q = d.inference(n_iter) - - return np.array(Q).reshape( - (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) - ) diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py index d8407a0f..76194e87 100644 --- a/napari_cellseg3d/code_plugins/plugin_crf.py +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial from pathlib import Path @@ -277,7 +278,10 @@ def _on_start(self): def _on_finish(self): self.worker = None - self.start_button.setText("Start") + with contextlib.suppress(RuntimeError): + self.start_button.setText("Start") + + # should only happen when testing def _on_error(self, error): logger.error(error) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 599ec5b3..256cffa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,6 +1,6 @@ from functools import partial from typing import TYPE_CHECKING -from pathlib import Path + import numpy as np import pandas as pd @@ -171,7 +171,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_size_choice = ui.DropdownMenu( sizes_window, text_label="Window size" ) - self.window_size_choice.setCurrentIndex(self._default_window_size) # set to 64 by default + self.window_size_choice.setCurrentIndex( + self._default_window_size + ) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -356,7 +358,7 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth, *.pt, *.onnx)", + file_extension="Weights file (*.pth *.pt *.onnx)", ) self._update_weights_path(file) diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index ee9919b6..2830f4e7 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -127,7 +127,7 @@ def evaluate_model_performance( ) if visualize: - viewer = napari.Viewer() + viewer = napari.Viewer(ndisplay=3) viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( diff --git a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py b/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py deleted file mode 100644 index 70ee10b6..00000000 --- a/napari_cellseg3d/dev_scripts/extract_extra_channels_labels.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -from skimage.filters import threshold_otsu -from skimage.segmentation import expand_labels -from tqdm import tqdm - - -def extract_labels_from_channels( # TODO add separate channels results - nuclei_labels: np.array, - extra_channels: list, - radius: int = 4, - threshold_factor=2, - viewer=None, -): - """ - Attemps to extract labels from other channels by expanding nuclei labels and picking the one with most pixels around it. - Args: - nuclei_labels (np.array): labels for the nuclei - extra_channels (list): channels arrays to extract labels from - radius: radius in which the approximation is made - - Returns: - A list of extracted labels for each extra channel - """ - labeled_channels = [] - contrasted_channels = [] - for channel in extra_channels: - channel = (channel - np.min(channel)) / ( - np.max(channel) - np.min(channel) - ) - threshold_brightness = threshold_otsu(channel) * threshold_factor - channel_contrasted = np.where( - channel > threshold_brightness, channel, 0 - ) - contrasted_channels.append(channel_contrasted) - if viewer is not None: - viewer.add_image( - channel_contrasted, - name="channel_contrasted", - colormap="viridis", - ) - for label_id in tqdm(np.unique(nuclei_labels)): - if label_id == 0: - continue - label_nucleus = np.where(nuclei_labels == label_id, nuclei_labels, 0) - expanded = expand_labels(label_nucleus, distance=radius) - restricted = np.where(expanded != 0, nuclei_labels, 0) - overlap = np.where(restricted != label_id, restricted, 0) - - for i, channel in enumerate(contrasted_channels): - label_contrasted = np.where(expanded != 0, channel, 0) - if overlap.any() != 0: - max_labeled = 0 - for overlap_id in np.unique(overlap): - if overlap_id == 0: - continue - assigned_pixels = np.count_nonzero( - np.where(overlap == overlap_id, channel, 0) - ) - if assigned_pixels > max_labeled: - max_labeled = assigned_pixels - max_label_id = overlap_id - if label_id != max_label_id: - labeled_channels.append( - np.zeros_like(label_contrasted) - ) - else: - labeled_channel = np.where(label_contrasted != 0, label_id, 0) - labeled_channels.append(labeled_channel) - if ( - np.count_nonzero(labeled_channel) > 0 - and viewer is not None - ): - viewer.add_labels( - labeled_channel, name=f"label_{label_id}_channel_{i+1}" - ) - - cat_labels = np.zeros_like(nuclei_labels) - for labels in np.unique(labeled_channels): - if labels == 0: - continue - cat_labels += np.where(labels != 0, labels, 0) - return cat_labels - - -if __name__ == "__main__": - from pathlib import Path - - import napari - from tifffile import imread - - image_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - ) - # image_path = Path.home() / "Desktop/Code/WNet-benchmark/results/showcase/WNet-labels-Voronoi-Otsu.tif" - nuclei_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/results/showcase/ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__DAPI_only.tif" - ) - extra_channels_path = ( - Path.home() - / "Desktop/Code/WNet-benchmark/dataset/wyss_data/batch_1/tmp" - ) - extra_channels = [ - imread(str(path)) - for path in extra_channels_path.glob( - "ELAST9_5_DAPI_Iba1_CD163_crop2_denoised__*.tif" - ) - ] - labels = imread(str(image_path)) - viewer = napari.Viewer() - - shift = 0 - viewer.add_image( - imread(str(nuclei_path))[ - shift : 32 + shift, shift : 32 + shift, shift : 32 + shift - ], - name="nuclei", - ) - viewer.add_labels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - [ - viewer.add_image( - channel[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - ) - for channel in extra_channels - ] - - labeled_channels = extract_labels_from_channels( - labels[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift], - [ - c[shift : 32 + shift, shift : 32 + shift, shift : 32 + shift] - for c in extra_channels - ], - radius=4, - viewer=viewer, - ) - - viewer.add_labels(labeled_channels) - # [viewer.add_labels(item, name=key) for key, item in labeled_channels.items()] - # expanded = expand_labels(labels, 4) - # viewer.add_labels(expanded) - napari.run() diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index de4d3206..06a2190a 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1208,7 +1208,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, possible_paths: list = (), - filetype: str = "Image file (*.tif *.tiff)", + file_extension: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1224,7 +1224,7 @@ def open_file_dialog( default_path = utils.parse_default_path(possible_paths) return QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype + widget, "Choose file", default_path, file_extension ) From 41b5ba49d47995314515fa5ff9056cb5ceef1deb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 10:41:07 +0200 Subject: [PATCH 565/577] Run all hooks --- .../_tests/test_plugin_inference.py | 5 ++++- .../code_models/models/model_TRAILMAP.py | 9 ++++----- .../code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_models/workers.py | 20 +++++++++++-------- napari_cellseg3d/dev_scripts/thread_test.py | 5 ++++- pyproject.toml | 2 +- 6 files changed, 26 insertions(+), 17 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1ae83102..1e486c14 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,9 +34,12 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() + widget.model_choice.setCurrentIndex(-1) + assert widget.window_infer_box.isChecked() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + widget.model_choice.setCurrentIndex(-1) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 8c7f3b70..e6bbad55 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -44,7 +44,7 @@ def forward(self, x): # print(out.shape) def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -57,7 +57,7 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -69,7 +69,7 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -84,10 +84,9 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): ) def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) - return out class TRAILMAP_(TRAILMAP): diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 3416acb1..2900b89c 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -204,4 +204,4 @@ def __init__(self, device, in_channels, out_channels, dropout=0.65): def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) \ No newline at end of file + return self.module(x.to(self.device)) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index c67ea523..245e6f02 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -199,28 +199,31 @@ def __init__(self): # TODO(cyril): move inference and training workers to separate files + class ONNXModelWrapper(torch.nn.Module): """Class to replace torch model by ONNX Runtime session""" + def __init__(self, file_location): super().__init__() try: - import onnx import onnxruntime as ort except ImportError as e: logger.error("ONNX is not installed but ONNX model was loaded") logger.error(e) msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]" logger.error(msg) - raise ImportError(msg) + raise ImportError(msg) from e self.ort_session = ort.InferenceSession( file_location, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) def forward(self, modeL_input): """Wraps ONNX output in a torch tensor""" - outputs = self.ort_session.run(None, {'input': modeL_input.cpu().numpy()}) + outputs = self.ort_session.run( + None, {"input": modeL_input.cpu().numpy()} + ) return torch.tensor(outputs[0]) def eval(self): @@ -231,6 +234,7 @@ def to(self, device): """Dummy function to replace model.to(device)""" pass + @dataclass class InferenceResult: """Class to record results of a segmentation job""" @@ -858,7 +862,7 @@ def inference(self): elif Path(weights_config.path).suffix == ".onnx": self.log("Instantiating ONNX model...") model = ONNXModelWrapper(weights_config.path) - else: # assume is .pth + else: # assume is .pth self.log("Instantiating model...") model = model_class( # FIXME test if works input_img_size=[dims, dims, dims], @@ -1606,8 +1610,8 @@ def get_loader_func(num_samples): yield train_report weights_filename = ( - f"{model_name}_best_metric" - + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1620,7 +1624,7 @@ def get_loader_func(num_samples): / Path( weights_filename, ), - ) + ) self.log("Saving complete") self.log( f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}" diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index b8dbc442..a48f6db0 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -2,6 +2,7 @@ import napari from napari.qt.threading import thread_worker +from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -12,6 +13,8 @@ QWidget, ) +rand_gen = Generator(PCG64(12345)) + @thread_worker def two_way_communication_with_args(start, end): @@ -128,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) + viewer = napari.view_image(rand_gen.random(512, 512)) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/pyproject.toml b/pyproject.toml index 2783761e..f71ddb23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ select = [ ] # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) # and 'G004' (do not use f-strings in logging) -ignore = ["E501", "E741", "G004"] +ignore = ["E501", "E741", "G004", "A003"] exclude = [ ".bzr", ".direnv", From 0463e2ebb2aebdd63207d58cc4b6129585270449 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 11:27:58 +0200 Subject: [PATCH 566/577] Fix inference testing --- .../_tests/test_plugin_inference.py | 13 +++++++----- .../code_models/models/model_test.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 1e486c14..779f5094 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -34,12 +34,15 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentIndex(-1) + widget.model_choice.setCurrentText("WNet") + widget._restrict_window_size_for_model() assert widget.window_infer_box.isChecked() + assert widget.window_size_choice.currentText() == "64" - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.model_choice.setCurrentIndex(-1) + test_model_name = "test" + MODEL_LIST[test_model_name] = TestModel + widget.model_choice.addItem(test_model_name) + widget.model_choice.setCurrentText(test_model_name) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None @@ -59,6 +62,6 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) - assert res.result.shape == (6, 6, 6) + assert res.result.shape == (8, 8, 8) widget.on_yield(res) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 1cb52f06..28f3a05b 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -20,13 +20,13 @@ def forward(self, x): # return val_inputs -# if __name__ == "__main__": -# -# model = TestModel() -# model.train() -# model.zero_grad() -# from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR -# torch.save( -# model.state_dict(), -# PRETRAINED_WEIGHTS_DIR + f"/{get_weights_file()}" -# ) +if __name__ == "__main__": + model = TestModel() + model.train() + model.zero_grad() + from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR + + torch.save( + model.state_dict(), + PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", + ) From b8bc533939911623d09ef4c5a9e5a9acca7ede37 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 2 Jun 2023 13:45:50 +0200 Subject: [PATCH 567/577] Changed anisotropy calculation --- napari_cellseg3d/_tests/test_interface.py | 7 ++++++- napari_cellseg3d/interface.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index 840f7a93..08e0e675 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,7 +1,6 @@ from napari_cellseg3d.interface import AnisotropyWidgets, Log - def test_log(qtbot): log = Log() log.print_and_log("test") @@ -13,3 +12,9 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) + + +def test_zoom_factor(): + resolution = [10.0, 10.0, 5.0] + zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) + assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 06a2190a..d2ec5789 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -735,8 +735,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = min(aniso_res) - return [base / res for res in aniso_res] + base = max(aniso_res) + return [res / base for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" From 560afc93779971f5494179527ccb7923fe49856b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 12:12:28 +0200 Subject: [PATCH 568/577] Fixed aniso correction and CRF interaction --- napari_cellseg3d/code_models/workers.py | 17 ++++++++++++++--- napari_cellseg3d/interface.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/workers.py b/napari_cellseg3d/code_models/workers.py index 245e6f02..50f85395 100644 --- a/napari_cellseg3d/code_models/workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -724,7 +724,12 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): instance_labels, stats = self.get_instance_result(out, i=i) if self.config.use_crf: try: - crf_results = self.run_crf(inputs, out, image_id=i) + crf_results = self.run_crf( + inputs, + out, + aniso_transform=self.aniso_transform, + image_id=i, + ) except ValueError as e: self.log(f"Error occurred during CRF : {e}") @@ -746,8 +751,10 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): i=i, ) - def run_crf(self, image, labels, image_id=0): + def run_crf(self, image, labels, aniso_transform, image_id=0): try: + if aniso_transform is not None: + image = aniso_transform(image) crf_results = crf_with_config( image, labels, config=self.config.crf_config, log=self.log ) @@ -795,7 +802,11 @@ def inference_on_layer(self, image, model, post_process_transforms): semantic_labels=out, from_layer=True ) - crf_results = self.run_crf(image, out) if self.config.use_crf else None + crf_results = ( + self.run_crf(image, out, aniso_transform=self.aniso_transform) + if self.config.use_crf + else None + ) return self.create_inference_result( semantic_labels=out, diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index d2ec5789..014c17b6 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -667,7 +667,7 @@ def __init__( w.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) self.box_widgets_lbl = [ - make_label("Resolution in " + axis + " (microns) :", parent=parent) + make_label("Pixel size in " + axis + " (microns) :", parent=parent) for axis in "xyz" ] From 91e923be580e84d5156ccd3547b8e1dafc139fcc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 10 Jun 2023 12:20:04 +0200 Subject: [PATCH 569/577] Remove duplicate tests --- .github/workflows/test_and_deploy.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index e9a66ae2..105c260a 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -7,17 +7,12 @@ on: push: branches: - main - - npe2 - - cy/voronoi-otsu - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - - npe2 - - cy/voronoi-otsu - - cy/wnet workflow_dispatch: jobs: From 04eae7efd0f13d2c9641465e375a001eaccf45cf Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sun, 11 Jun 2023 13:31:16 +0200 Subject: [PATCH 570/577] Finish rebase + changed step to auto in spinbox --- .github/workflows/test_and_deploy.yml | 1 - docs/res/welcome.rst | 2 +- napari_cellseg3d/__init__.py | 2 +- .../code_models/instance_segmentation.py | 8 +------- .../code_models/models/model_SwinUNetR.py | 1 - .../code_models/models/model_TRAILMAP_MS.py | 1 - napari_cellseg3d/code_plugins/plugin_helper.py | 2 +- napari_cellseg3d/interface.py | 16 +++++++++------- pyproject.toml | 3 ++- setup.cfg | 2 +- 10 files changed, 16 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 105c260a..406bf4f5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -7,7 +7,6 @@ on: push: branches: - main - - cy/wnet tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 045297f6..892549a8 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -113,7 +113,7 @@ This plugin mainly uses the following libraries and software: .. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype - +.. _WNet model: https://arxiv.org/abs/1711.08506 .. rubric:: References diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 6e2681e8..be8123e4 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc1" +__version__ = "0.0.3rc1" diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 6037a733..dc530fa8 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -9,19 +9,13 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes from tifffile import imread +# local from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis -# from skimage.measure import marching_cubes -# from skimage.measure import mesh_surface_area - - # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 144317f8..2d7b5ef6 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -4,7 +4,6 @@ logger = LOGGER -from napari_cellseg3d.utils import LOGGER class SwinUNETR_(SwinUNETR): use_default_training = True diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index e42d54bf..baf8635d 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -3,7 +3,6 @@ logger = LOGGER -logger = LOGGER class TRAILMAP_MS_(UNet3D): use_default_training = True diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 54c34a8f..552f70ea 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -39,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.3rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 014c17b6..a73ea62d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -13,6 +13,7 @@ from qtpy.QtCore import QObject, Qt, QUrl from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor from qtpy.QtWidgets import ( + QAbstractSpinBox, QCheckBox, QComboBox, QDoubleSpinBox, @@ -1085,13 +1086,14 @@ def __init__( if text_label is not None: self.label = make_label(name=text_label) - self.valueChanged.connect(self._update_step) - - def _update_step(self): - if self.value() < 0.9: - self.setSingleStep(0.01) - else: - self.setSingleStep(0.1) + # self.valueChanged.connect(self._update_step) + self.setStepType(QAbstractSpinBox.StepType.AdaptiveDecimalStepType) + + # def _update_step(self): + # if self.value() <= 1: + # self.setSingleStep(0.1) + # else: + # self.setSingleStep(1) @property def tooltips(self): diff --git a/pyproject.toml b/pyproject.toml index f71ddb23..e39a7522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "napari_cellseg3d" -version = "0.0.2rc6" +version = "0.0.3rc1" authors = [ {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, @@ -102,6 +102,7 @@ dev = [ "black", "ruff", "pre-commit", + "tuna", ] docs = [ "sphinx", diff --git a/setup.cfg b/setup.cfg index f3294b60..8ee82f96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.2rc6 +version = 0.0.3rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu From 7c6e3c8dd1e9f0c35c5de14e254275d514885055 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 12 Jun 2023 12:10:24 +0200 Subject: [PATCH 571/577] Updated based on feedback from CYHSM Co-Authored-By: Markus Frey <5563464+CYHSM@users.noreply.github.com> --- .../code_models/instance_segmentation.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 455812e1..afdf3c15 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -52,6 +52,18 @@ def __init__( self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] self.sliders: List[ui.Slider] = [] + self._setup_widgets( + num_counters, num_sliders, widget_parent=widget_parent + ) + + def _setup_widgets(self, num_counters, num_sliders, widget_parent=None): + """Initializes the needed widgets for the instance segmentation method, adding sliders and counters to the + instance segmentation widget. + Args: + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + num_sliders: Number of Slider UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets + """ if num_sliders > 0: for i in range(num_sliders): widget = f"slider_{i}" @@ -154,6 +166,8 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Original code at : + https://github.com/clEsperanto/pyclesperanto_prototype/blob/master/pyclesperanto_prototype/_tier9/_voronoi_otsu_labeling.py Args: volume (np.ndarray): volume to segment From 41a2194bdc4f69c5838314ba698ca6756a4fb942 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 10 Jul 2023 10:33:34 +0200 Subject: [PATCH 572/577] Added minimal WNet notebook for training --- napari_cellseg3d/_tests/test_models.py | 4 +- .../code_models/instance_segmentation.py | 34 +- .../code_models/models/model_VNet.py | 17 - .../code_models/models/wnet/model.py | 129 ++- .../code_models/models/wnet/soft_Ncuts.py | 169 +-- .../code_models/models/wnet/train_wnet.py | 1008 +++++++++++++++++ notebooks/train_wnet.ipynb | 267 +++++ 7 files changed, 1393 insertions(+), 235 deletions(-) create mode 100644 napari_cellseg3d/code_models/models/wnet/train_wnet.py create mode 100644 notebooks/train_wnet.ipynb diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index ec7462db..3845bb6a 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -42,8 +42,8 @@ def test_soft_ncuts_loss(): loss = SoftNCutsLoss( data_shape=[dims, dims, dims], device="cpu", - o_i=4, - o_x=4, + intensity_sigma=4, + spatial_sigma=4, radius=2, ) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index afdf3c15..535bd429 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -313,23 +313,23 @@ def clear_small_objects(image, threshold, is_file_path=False): return result -def to_instance(image, is_file_path=False): - """Converts a **ground-truth** label to instance (unique id per object) labels. Does not remove small objects. - - Args: - image: image or path to image - is_file_path: if True, will consider ``image`` to be a string containing a path to a file, if not treats it as an image data array. - - Returns: resulting converted labels - - """ - if is_file_path: - image = [imread(image)] - # image = image.compute() - - return binary_watershed( - image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # FIXME add params from utils plugin +# def to_instance(image, is_file_path=False): +# """Converts a **ground-truth** label to instance (unique id per object) labels. Does not remove small objects. +# +# Args: +# image: image or path to image +# is_file_path: if True, will consider ``image`` to be a string containing a path to a file, if not treats it as an image data array. +# +# Returns: resulting converted labels +# +# """ +# if is_file_path: +# image = [imread(image)] +# image = image.compute() +# +# return binary_watershed( +# image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 +# ) def to_semantic(image, is_file_path=False): diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 41554e80..b082ccab 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -14,20 +14,3 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs): super().__init__( in_channels=in_channels, out_channels=out_channels ) - - # def get_output(self, input): - # out = self(input) - # return out - - # def get_validation(self, val_inputs): # FIXME standardize - # roi_size = (64, 64, 64) - # sw_batch_size = 1 - # val_outputs = sliding_window_inference( - # val_inputs, - # roi_size, - # sw_batch_size, - # self, - # # mode="gaussian", - # # overlap=0.7, - # ) - # return val_outputs diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 2900b89c..e1dfbec8 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -3,6 +3,8 @@ The model performs unsupervised segmentation of 3D images. """ +from typing import List + import torch import torch.nn as nn @@ -22,7 +24,11 @@ class WNet_encoder(nn.Module): def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): super().__init__() self.device = device - self.encoder = UNet(device, in_channels, num_classes, encoder=True) + self.encoder = UNet( + in_channels=in_channels, + out_channels=out_channels, + encoder=True, + ) def forward(self, x): """Forward pass of the W-Net model.""" @@ -35,17 +41,25 @@ class WNet(nn.Module): It first encodes the input image into a latent space using the U-Net UEncoder, then decodes it back to the original image using the U-Net UDecoder. """ - def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + def __init__( + self, + in_channels=1, + out_channels=1, + num_classes=2, + dropout=0.65, + ): super(WNet, self).__init__() - self.device = device - self.encoder = UNet(device, in_channels, num_classes, encoder=True) - self.decoder = UNet(device, num_classes, out_channels, encoder=False) + self.encoder = UNet( + in_channels, num_classes, encoder=True, dropout=dropout + ) + self.decoder = UNet( + num_classes, out_channels, encoder=False, dropout=dropout + ) def forward(self, x): """Forward pass of the W-Net model.""" enc = self.forward_encoder(x) - dec = self.forward_decoder(enc) - return enc, dec + return enc, self.forward_decoder(enc) def forward_encoder(self, x): """Forward pass of the encoder part of the W-Net model.""" @@ -60,39 +74,52 @@ class UNet(nn.Module): """Half of the W-Net model, based on the U-Net architecture.""" def __init__( - self, device, in_channels, out_channels, encoder=True, dropout=0.65 + self, + # device, + in_channels: int, + out_channels: int, + channels: List[int] = None, + encoder: bool = True, + dropout: float = 0.65, ): + if channels is None: + channels = [64, 128, 256, 512, 1024] + if len(channels) != 5: + raise ValueError( + "Channels must be a list of channels in the form: [64, 128, 256, 512, 1024]" + ) super(UNet, self).__init__() - self.device = device + # self.device = device + self.channels = channels self.max_pool = nn.MaxPool3d(2) - self.in_b = InBlock(device, in_channels, 64, dropout=dropout) - self.conv1 = Block(device, 64, 128, dropout=dropout) - self.conv2 = Block(device, 128, 256, dropout=dropout) - self.conv3 = Block(device, 256, 512, dropout=dropout) - self.bot = Block(device, 512, 1024, dropout=dropout) - self.deconv1 = Block(device, 1024, 512, dropout=dropout) + self.in_b = InBlock(in_channels, self.channels[0], dropout=dropout) + self.conv1 = Block(channels[0], self.channels[1], dropout=dropout) + self.conv2 = Block(channels[1], self.channels[2], dropout=dropout) + self.conv3 = Block(channels[2], self.channels[3], dropout=dropout) + self.bot = Block(channels[3], self.channels[4], dropout=dropout) + self.deconv1 = Block(channels[4], self.channels[3], dropout=dropout) self.conv_trans1 = nn.ConvTranspose3d( - 1024, 512, 2, stride=2, device=self.device + self.channels[4], self.channels[3], 2, stride=2 ) - self.deconv2 = Block(device, 512, 256, dropout=dropout) + self.deconv2 = Block(channels[3], self.channels[2], dropout=dropout) self.conv_trans2 = nn.ConvTranspose3d( - 512, 256, 2, stride=2, device=self.device + self.channels[3], self.channels[2], 2, stride=2 ) - self.deconv3 = Block(device, 256, 128, dropout=dropout) + self.deconv3 = Block(channels[2], self.channels[1], dropout=dropout) self.conv_trans3 = nn.ConvTranspose3d( - 256, 128, 2, stride=2, device=self.device + self.channels[2], self.channels[1], 2, stride=2 ) - self.out_b = OutBlock(device, 128, out_channels, dropout=dropout) + self.out_b = OutBlock(channels[1], out_channels, dropout=dropout) self.conv_trans_out = nn.ConvTranspose3d( - 128, 64, 2, stride=2, device=self.device + self.channels[1], self.channels[0], 2, stride=2 ) - self.sm = nn.Softmax(dim=1).to(device) + self.sm = nn.Softmax(dim=1) self.encoder = encoder def forward(self, x): """Forward pass of the U-Net model.""" - in_b = self.in_b(x.to(self.device)) + in_b = self.in_b(x) c1 = self.conv1(self.max_pool(in_b)) c2 = self.conv2(self.max_pool(c1)) c3 = self.conv3(self.max_pool(c2)) @@ -141,67 +168,67 @@ def forward(self, x): class InBlock(nn.Module): """Input block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, in_channels, out_channels, dropout=0.65): super(InBlock, self).__init__() - self.device = device + # self.device = device self.module = nn.Sequential( - nn.Conv3d(in_channels, out_channels, 3, padding=1, device=device), + nn.Conv3d(in_channels, out_channels, 3, padding=1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(out_channels, device=device), - nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), + nn.BatchNorm3d(out_channels), + nn.Conv3d(out_channels, out_channels, 3, padding=1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(out_channels, device=device), - ).to(device) + nn.BatchNorm3d(out_channels), + ) def forward(self, x): """Forward pass of the input block.""" - return self.module(x.to(self.device)) + return self.module(x) class Block(nn.Module): """Basic block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, in_channels, out_channels, dropout=0.65): super(Block, self).__init__() - self.device = device + # self.device = device self.module = nn.Sequential( - nn.Conv3d(in_channels, in_channels, 3, padding=1, device=device), - nn.Conv3d(in_channels, out_channels, 1, device=device), + nn.Conv3d(in_channels, in_channels, 3, padding=1), + nn.Conv3d(in_channels, out_channels, 1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(out_channels, device=device), - nn.Conv3d(out_channels, out_channels, 3, padding=1, device=device), - nn.Conv3d(out_channels, out_channels, 1, device=device), + nn.BatchNorm3d(out_channels), + nn.Conv3d(out_channels, out_channels, 3, padding=1), + nn.Conv3d(out_channels, out_channels, 1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(out_channels, device=device), - ).to(device) + nn.BatchNorm3d(out_channels), + ) def forward(self, x): """Forward pass of the basic block.""" - return self.module(x.to(self.device)) + return self.module(x) class OutBlock(nn.Module): """Output block of the U-Net architecture.""" - def __init__(self, device, in_channels, out_channels, dropout=0.65): + def __init__(self, in_channels, out_channels, dropout=0.65): super(OutBlock, self).__init__() - self.device = device + # self.device = device self.module = nn.Sequential( - nn.Conv3d(in_channels, 64, 3, padding=1, device=device), + nn.Conv3d(in_channels, 64, 3, padding=1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(64, device=device), - nn.Conv3d(64, 64, 3, padding=1, device=device), + nn.BatchNorm3d(64), + nn.Conv3d(64, 64, 3, padding=1), nn.ReLU(), nn.Dropout(p=dropout), - nn.BatchNorm3d(64, device=device), - nn.Conv3d(64, out_channels, 1, device=device), - ).to(device) + nn.BatchNorm3d(64), + nn.Conv3d(64, out_channels, 1), + ) def forward(self, x): """Forward pass of the output block.""" - return self.module(x.to(self.device)) + return self.module(x) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py index 938292c2..e0f92ff7 100644 --- a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -30,15 +30,17 @@ class SoftNCutsLoss(nn.Module): Args: data_shape (H, W, D): shape of the images as a tuple. - o_i (scalar): scale of the gaussian kernel of pixels brightness. - o_x (scalar): scale of the gaussian kernel of pixels spacial distance. + intensity_sigma (scalar): scale of the gaussian kernel of pixels brightness. + spatial_sigma (scalar): scale of the gaussian kernel of pixels spacial distance. radius (scalar): radius of pixels for which we compute the weights """ - def __init__(self, data_shape, device, o_i, o_x, radius=None): + def __init__( + self, data_shape, device, intensity_sigma, spatial_sigma, radius=None + ): super(SoftNCutsLoss, self).__init__() - self.o_i = o_i - self.o_x = o_x + self.intensity_sigma = intensity_sigma + self.spatial_sigma = spatial_sigma self.radius = radius self.H = data_shape[0] self.W = data_shape[1] @@ -52,73 +54,7 @@ def __init__(self, data_shape, device, o_i, o_x, radius=None): self.W, self.D, ) - - # self.distances, self.indexes = self.get_distances() - - """ - - # Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration - distances_H = torch.tensor(range(self.H)).expand(self.H, self.H) # (H, H) - distances_W = torch.tensor(range(self.W)).expand(self.W, self.W) # (W, W) - distances_D = torch.tensor(range(self.D)).expand(self.D, self.D) # (D, D) - - # Compute in cuda if possible - if torch.cuda.is_available(): - distances_H = distances_H.cuda() - distances_W = distances_W.cuda() - distances_D = distances_D.cuda() - - distances_H = torch.abs(torch.subtract(distances_H, distances_H.T)) # (H, H) - distances_W = torch.abs(torch.subtract(distances_W, distances_W.T)) # (W, W) - distances_D = torch.abs(torch.subtract(distances_D, distances_D.T)) # (D, D) - - distances_H = distances_H.view(self.H, 1, 1, self.H, 1, 1).expand( - self.H, self.W, self.D, self.H, self.W, self.D - ).to_sparse() # (H, 1, 1, H, 1, 1) -> (H, W, D, H, W, D) - distances_W = distances_W.view(1, self.W, 1, 1, self.W, 1).expand( - self.H, self.W, self.D, self.H, self.W, self.D - ).to_sparse() # (1, W, 1, 1, W, 1) -> (H, W, D, H, W, D) - distances_D = distances_D.view(1, 1, self.D, 1, 1, self.D).expand( - self.H, self.W, self.D, self.H, self.W, self.D - ).to_sparse() # (1, 1, D, 1, 1, D) -> (H, W, D, H, W, D) - - mask_H = torch.le(distances_H, self.radius).bool() # (H, W, D, H, W, D) - mask_W = torch.le(distances_W, self.radius).bool() # (H, W, D, H, W, D) - mask_D = torch.le(distances_D, self.radius).bool() # (H, W, D, H, W, D) - - distances_H = (distances_H * mask_H) # (H, W, D, H, W, D) - distances_W = (distances_W * mask_W) # (H, W, D, H, W, D) - distances_D = (distances_D * mask_D) # (H, W, D, H, W, D) - - mask_H =mask_H.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) - mask_W =mask_W.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) - mask_D =mask_D.flatten(0, 2).flatten(1, 3) # (H, W, D, H, W, D) - - distances_H = distances_H.pow(2) # (H, W, D, H, W, D) - distances_W = distances_W.pow(2) # (H, W, D, H, W, D) - distances_D = distances_D.pow(2) # (H, W, D, H, W, D) - - squared_distances = torch.add( - torch.add(distances_H, distances_W), - distances_D, - ) # (H, W, D, H, W, D) - - squared_distances = squared_distances.flatten(0, 2).flatten( - 1, 3 - ) # (H*W*D, H*W*D) - - # Mask to only keep the weights for the pixels in the radius - self.mask = torch.le(squared_distances, self.radius**2).bool() # (H*W*D, H*W*D) - - # Add all masks to get the final mask - self.mask = self.mask.logical_and(mask_H).logical_and(mask_W).logical_and(mask_D) # (H*W*D, H*W*D) - - W_X = torch.exp( - torch.neg(torch.div(squared_distances, self.o_x)) - ) # (H*W*D, H*W*D) - - self.W_X = torch.mul(W_X, self.mask) # (H*W*D, H*W*D) - """ + print(f"Radius set to {self.radius}") def forward(self, labels, inputs): """Forward pass of the Soft N-Cuts loss. @@ -130,8 +66,8 @@ def forward(self, labels, inputs): Returns: The Soft N-Cuts loss of shape (N,). """ - inputs.shape[0] - inputs.shape[1] + # inputs.shape[0] + # inputs.shape[1] K = labels.shape[1] labels.to(self.device) @@ -139,7 +75,9 @@ def forward(self, labels, inputs): loss = 0 - kernel = self.gaussian_kernel(self.radius, self.o_x).to(self.device) + kernel = self.gaussian_kernel(self.radius, self.spatial_sigma).to( + self.device + ) for k in range(K): # Compute the average pixel value for this class, and the difference from each pixel @@ -152,7 +90,9 @@ def forward(self, labels, inputs): diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1) # Weight the loss by the difference from the class average. - weights = torch.exp(diff.pow(2).mul(-1 / self.o_i**2)) + weights = torch.exp( + diff.pow(2).mul(-1 / self.intensity_sigma**2) + ) numerator = torch.sum( class_probs @@ -170,44 +110,6 @@ def forward(self, labels, inputs): return K - loss - """ - for k in range(K): - Ak = labels[:, k, :, :, :] # (N, H, W, D) - flatted_Ak = Ak.view(N, -1) # (N, H*W*D) - - # Compute the numerator of the Soft N-Cuts loss for k - flatted_Ak_unsqueeze = flatted_Ak.unsqueeze(1) # (N, 1, H*W*D) - transposed_Ak = torch.transpose(flatted_Ak_unsqueeze, 1, 2) # (N, H*W*D, 1) - probs = torch.bmm(transposed_Ak, flatted_Ak_unsqueeze) # (N, H*W*D, H*W*D) - probs_unsqueeze_expanded = probs.unsqueeze(1) # (N, 1, H*W*D, H*W*D) - numerator_elements = torch.mul( - probs_unsqueeze_expanded, weights - ) # (N, C, H*W*D, H*W*D) - numerator = torch.sum(numerator_elements, dim=(2, 3)) # (N, C) - - # Compute the denominator of the Soft N-Cuts loss for k - expanded_flatted_Ak = flatted_Ak.expand( - -1, self.H * self.W * self.D - ) # (N, H*W*D, H*W*D) - e_f_Ak_unsqueeze_expanded = expanded_flatted_Ak.unsqueeze( - 1 - ) # (N, 1, H*W*D, H*W*D) - denominator_elements = torch.mul( - e_f_Ak_unsqueeze_expanded, weights - ) # (N, C, H*W*D, H*W*D) - denominator = torch.sum(denominator_elements, dim=(2, 3)) # (N, C) - - # Compute the Soft N-Cuts loss for k - division = torch.div(numerator, torch.add(denominator, 1e-8)) # (N, C) - loss = torch.sum(division, dim=1) # (N,) - losses.append(loss) - - loss = torch.sum(torch.stack(losses, dim=0), dim=0) # (N,) - - return torch.add(torch.neg(loss), K) - """ - return None - def gaussian_kernel(self, radius, sigma): """Computes the Gaussian kernel. @@ -229,12 +131,10 @@ def gaussian_kernel(self, radius, sigma): ) kernel = norm.pdf(dist) / norm.pdf(0) kernel = torch.from_numpy(kernel.astype(np.float32)) - kernel = kernel.view( + return kernel.view( (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]) ) - return kernel - def get_distances(self): """Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration. @@ -283,7 +183,9 @@ def get_distances(self): distance = np.linalg.norm(i - j) if distance > self.radius: continue - distance = math.exp(-(distance**2) / (self.o_x**2)) + distance = math.exp( + -(distance**2) / (self.spatial_sigma**2) + ) if jTuple not in distances: distances[iTuple][jTuple] = distance @@ -300,35 +202,6 @@ def get_weights(self, inputs): list: List of the weights dict for each image in the batch. """ - """ - weights = [] - for n in range(inputs.shape[0]): - weightsChannel = [] - for c in range(inputs.shape[1]): - weightsImage = dict() - for i in self.indexes: - iTuple = (i[0], i[1], i[2]) - weightsImage[iTuple] = dict() - for j in self.indexes: - jTuple = (j[0], j[1], j[2]) - if iTuple in self.distances and jTuple in self.distances[i]: - brightness = ( - inputs[n][c][i[0]][i[1]][i[2]] - - inputs[n][c][j[0]][j[1]][j[2]] - ) ** 2 - brightness = math.exp(-brightness / self.o_i**2) - weightsImage[iTuple][jTuple] = ( - self.distances[iTuple][jTuple] * brightness - ) - - weightsChannel.append(weightsImage) - - weights.append(weightsChannel) - - return weights - - """ - # Compute the brightness distance of the pixels flatted_inputs = inputs.view( inputs.shape[0], inputs.shape[1], -1 @@ -340,7 +213,7 @@ def get_weights(self, inputs): squared_I_diff = torch.pow(masked_I_diff, 2) # (N, C, H*W*D, H*W*D) W_I = torch.exp( - torch.neg(torch.div(squared_I_diff, self.o_i)) + torch.neg(torch.div(squared_I_diff, self.intensity_sigma)) ) # (N, C, H*W*D, H*W*D) W_I = torch.mul(W_I, self.mask) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/models/wnet/train_wnet.py b/napari_cellseg3d/code_models/models/wnet/train_wnet.py new file mode 100644 index 00000000..61d8959a --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/train_wnet.py @@ -0,0 +1,1008 @@ +""" +This file contains the code to train the WNet model. +""" +# import napari +import glob +import time +from pathlib import Path +from typing import Union +from warnings import warn + +import numpy as np +import tifffile as tiff +import torch +import torch.nn as nn + +# MONAI +from monai.data import ( + CacheDataset, + DataLoader, + PatchDataset, + pad_list_data_collate, +) +from monai.data.meta_obj import set_track_meta +from monai.metrics import DiceMetric +from monai.transforms import ( + AsDiscrete, + Compose, + EnsureChannelFirst, + EnsureChannelFirstd, + EnsureTyped, + LoadImaged, + Orientationd, + RandFlipd, + RandRotate90d, + RandShiftIntensityd, + RandSpatialCropSamplesd, + ScaleIntensityRanged, + SpatialPadd, + ToTensor, +) +from monai.utils.misc import set_determinism + +# local +from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss +from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import dice_coeff, get_padding_dim + +try: + import wandb + + WANDB_INSTALLED = True +except ImportError: + warn( + "wandb not installed, wandb config will not be taken into account", + stacklevel=1, + ) + WANDB_INSTALLED = False + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" + + +########################## +# Utils functions # +########################## + + +def create_dataset_dict(volume_directory, label_directory): + """Creates data dictionary for MONAI transforms and training.""" + images_filepaths = sorted( + [str(file) for file in Path(volume_directory).glob("*.tif")] + ) + + labels_filepaths = sorted( + [str(file) for file in Path(label_directory).glob("*.tif")] + ) + if len(images_filepaths) == 0 or len(labels_filepaths) == 0: + raise ValueError( + f"Data folders are empty \n{volume_directory} \n{label_directory}" + ) + + logger.info("Images :") + for file in images_filepaths: + logger.info(Path(file).stem) + logger.info("*" * 10) + logger.info("Labels :") + for file in labels_filepaths: + logger.info(Path(file).stem) + try: + data_dicts = [ + {"image": image_name, "label": label_name} + for image_name, label_name in zip( + images_filepaths, labels_filepaths + ) + ] + except ValueError as e: + raise ValueError( + f"Number of images and labels does not match : \n{volume_directory} \n{label_directory}" + ) from e + # print(f"Loaded eval image: {data_dicts}") + return data_dicts + + +def create_dataset_dict_no_labs(volume_directory): + """Creates unsupervised data dictionary for MONAI transforms and training.""" + images_filepaths = sorted(glob.glob(str(Path(volume_directory) / "*.tif"))) + if len(images_filepaths) == 0: + raise ValueError(f"Data folder {volume_directory} is empty") + + logger.info("Images :") + for file in images_filepaths: + logger.info(Path(file).stem) + logger.info("*" * 10) + + return [{"image": image_name} for image_name in images_filepaths] + + +def remap_image( + image: Union[np.ndarray, torch.Tensor], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + # image = set_quantile_to_value(image) + return image.reshape(shape) + + +################################ +# Config & WANDB # +################################ + + +class Config: + def __init__(self): + # WNet + self.in_channels = 1 + self.out_channels = 1 + self.num_classes = 2 + self.dropout = 0.65 + self.use_clipping = False + self.clipping = 1 + + self.lr = 1e-6 + self.scheduler = "None" # "CosineAnnealingLR" # "ReduceLROnPlateau" + self.weight_decay = 0.01 # None + + self.intensity_sigma = 1 + self.spatial_sigma = 4 + self.radius = 2 # yields to a radius depending on the data shape + + self.n_cuts_weight = 0.5 + self.reconstruction_loss = "MSE" # "BCE" + self.rec_loss_weight = 0.5 / 100 + + self.num_epochs = 100 + self.val_interval = 5 + self.batch_size = 2 + self.num_workers = 4 + + # CRF + self.sa = 50 # 10 + self.sb = 20 + self.sg = 1 + self.w1 = 50 # 10 + self.w2 = 20 + self.n_iter = 5 + + # Data + self.train_volume_directory = "./../dataset/VIP_full" + self.eval_volume_directory = "./../dataset/VIP_cropped/eval/" + self.normalize_input = True + self.normalizing_function = remap_image # normalize_quantile + self.use_patch = False + self.patch_size = (64, 64, 64) + self.num_patches = 30 + self.eval_num_patches = 20 + self.do_augmentation = True + self.parallel = False + + self.save_model = True + self.save_model_path = ( + r"./../results/new_model/wnet_new_model_all_data_3class.pth" + ) + # self.save_losses_path = ( + # r"./../results/new_model/wnet_new_model_all_data_3class.pkl" + # ) + self.save_every = 5 + self.weights_path = None + + +c = Config() +############### +# Scheduler config +############### +schedulers = { + "ReduceLROnPlateau": { + "factor": 0.5, + "patience": 50, + }, + "CosineAnnealingLR": { + "T_max": 25000, + "eta_min": 1e-8, + }, + "CosineAnnealingWarmRestarts": { + "T_0": 50000, + "eta_min": 1e-8, + "T_mult": 1, + }, + "CyclicLR": { + "base_lr": 2e-7, + "max_lr": 2e-4, + "step_size_up": 250, + "mode": "triangular", + }, +} + +############### +# WANDB_CONFIG +############### +WANDB_MODE = "disabled" +# WANDB_MODE = "online" + +WANDB_CONFIG = { + # data setting + "num_workers": c.num_workers, + "normalize": c.normalize_input, + "use_patch": c.use_patch, + "patch_size": c.patch_size, + "num_patches": c.num_patches, + "eval_num_patches": c.eval_num_patches, + "do_augmentation": c.do_augmentation, + "model_save_path": c.save_model_path, + # train setting + "batch_size": c.batch_size, + "learning_rate": c.lr, + "weight_decay": c.weight_decay, + "scheduler": { + "name": c.scheduler, + "ReduceLROnPlateau_config": { + "factor": schedulers["ReduceLROnPlateau"]["factor"], + "patience": schedulers["ReduceLROnPlateau"]["patience"], + }, + "CosineAnnealingLR_config": { + "T_max": schedulers["CosineAnnealingLR"]["T_max"], + "eta_min": schedulers["CosineAnnealingLR"]["eta_min"], + }, + "CosineAnnealingWarmRestarts_config": { + "T_0": schedulers["CosineAnnealingWarmRestarts"]["T_0"], + "eta_min": schedulers["CosineAnnealingWarmRestarts"]["eta_min"], + "T_mult": schedulers["CosineAnnealingWarmRestarts"]["T_mult"], + }, + "CyclicLR_config": { + "base_lr": schedulers["CyclicLR"]["base_lr"], + "max_lr": schedulers["CyclicLR"]["max_lr"], + "step_size_up": schedulers["CyclicLR"]["step_size_up"], + "mode": schedulers["CyclicLR"]["mode"], + }, + }, + "max_epochs": c.num_epochs, + "save_every": c.save_every, + "val_interval": c.val_interval, + # loss + "reconstruction_loss": c.reconstruction_loss, + "loss weights": { + "n_cuts_weight": c.n_cuts_weight, + "rec_loss_weight": c.rec_loss_weight, + }, + "loss_params": { + "intensity_sigma": c.intensity_sigma, + "spatial_sigma": c.spatial_sigma, + "radius": c.radius, + }, + # model + "model_type": "wnet", + "model_params": { + "in_channels": c.in_channels, + "out_channels": c.out_channels, + "num_classes": c.num_classes, + "dropout": c.dropout, + "use_clipping": c.use_clipping, + "clipping_value": c.clipping, + }, + # CRF + "crf_params": { + "sa": c.sa, + "sb": c.sb, + "sg": c.sg, + "w1": c.w1, + "w2": c.w2, + "n_iter": c.n_iter, + }, +} + + +def train(weights_path=None, train_config=None): + if train_config is None: + config = Config() + ############## + # disable metadata tracking + set_track_meta(False) + ############## + if WANDB_INSTALLED: + wandb.init( + config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE + ) + + set_determinism(seed=34936339) # use default seed from NP_MAX + torch.use_deterministic_algorithms(True, warn_only=True) + + config = train_config + normalize_function = config.normalizing_function + CUDA = torch.cuda.is_available() + device = torch.device("cuda" if CUDA else "cpu") + + print(f"Using device: {device}") + + print("Config:") + [print(a) for a in config.__dict__.items()] + + print("Initializing training...") + print("Getting the data") + + if config.use_patch: + (data_shape, dataset) = get_patch_dataset(config) + else: + (data_shape, dataset) = get_dataset(config) + transform = Compose( + [ + ToTensor(), + EnsureChannelFirst(channel_dim=0), + ] + ) + dataset = [transform(im) for im in dataset] + for data in dataset: + print(f"data shape: {data.shape}") + break + + dataloader = DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=True, + num_workers=config.num_workers, + collate_fn=pad_list_data_collate, + ) + + if config.eval_volume_directory is not None: + eval_dataset = get_patch_eval_dataset(config) + + eval_dataloader = DataLoader( + eval_dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=config.num_workers, + collate_fn=pad_list_data_collate, + ) + + dice_metric = DiceMetric( + include_background=False, reduction="mean", get_not_nans=False + ) + ################################################### + # Training the model # + ################################################### + print("Initializing the model:") + + print("- getting the model") + # Initialize the model + model = WNet( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_classes=config.num_classes, + dropout=config.dropout, + ) + model = ( + nn.DataParallel(model).cuda() if CUDA and config.parallel else model + ) + model.to(device) + + if config.use_clipping: + for p in model.parameters(): + p.register_hook( + lambda grad: torch.clamp( + grad, min=-config.clipping, max=config.clipping + ) + ) + + if WANDB_INSTALLED: + wandb.watch(model, log_freq=100) + + if weights_path is not None: + model.load_state_dict(torch.load(weights_path, map_location=device)) + + print("- getting the optimizers") + # Initialize the optimizers + if config.weight_decay is not None: + decay = config.weight_decay + optimizer = torch.optim.Adam( + model.parameters(), lr=config.lr, weight_decay=decay + ) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) + + print("- getting the loss functions") + # Initialize the Ncuts loss function + criterionE = SoftNCutsLoss( + data_shape=data_shape, + device=device, + intensity_sigma=config.intensity_sigma, + spatial_sigma=config.spatial_sigma, + radius=config.radius, + ) + + if config.reconstruction_loss == "MSE": + criterionW = nn.MSELoss() + elif config.reconstruction_loss == "BCE": + criterionW = nn.BCELoss() + else: + raise ValueError( + f"Unknown reconstruction loss : {config.reconstruction_loss} not supported" + ) + + print("- getting the learning rate schedulers") + # Initialize the learning rate schedulers + scheduler = get_scheduler(config, optimizer) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + # optimizer, mode="min", factor=0.5, patience=10, verbose=True + # ) + model.train() + + print("Ready") + print("Training the model") + print("*" * 50) + + startTime = time.time() + ncuts_losses = [] + rec_losses = [] + total_losses = [] + best_dice = -1 + best_dice_epoch = -1 + + # Train the model + for epoch in range(config.num_epochs): + print(f"Epoch {epoch + 1} of {config.num_epochs}") + + epoch_ncuts_loss = 0 + epoch_rec_loss = 0 + epoch_loss = 0 + + for _i, batch in enumerate(dataloader): + # raise NotImplementedError("testing") + if config.use_patch: + image = batch["image"].to(device) + else: + image = batch.to(device) + if config.batch_size == 1: + image = image.unsqueeze(0) + else: + image = image.unsqueeze(0) + image = torch.swapaxes(image, 0, 1) + + # Forward pass + enc = model.forward_encoder(image) + # out = model.forward(image) + + # Compute the Ncuts loss + Ncuts = criterionE(enc, image) + epoch_ncuts_loss += Ncuts.item() + if WANDB_INSTALLED: + wandb.log({"Ncuts loss": Ncuts.item()}) + + # Forward pass + enc, dec = model(image) + + # Compute the reconstruction loss + if isinstance(criterionW, nn.MSELoss): + reconstruction_loss = criterionW(dec, image) + elif isinstance(criterionW, nn.BCELoss): + reconstruction_loss = criterionW( + torch.sigmoid(dec), + remap_image(image, new_max=1), + ) + + epoch_rec_loss += reconstruction_loss.item() + if WANDB_INSTALLED: + wandb.log({"Reconstruction loss": reconstruction_loss.item()}) + + # Backward pass for the reconstruction loss + optimizer.zero_grad() + alpha = config.n_cuts_weight + beta = config.rec_loss_weight + + loss = alpha * Ncuts + beta * reconstruction_loss + epoch_loss += loss.item() + if WANDB_INSTALLED: + wandb.log({"Sum of losses": loss.item()}) + loss.backward(loss) + optimizer.step() + + if config.scheduler == "CosineAnnealingWarmRestarts": + scheduler.step(epoch + _i / len(dataloader)) + if ( + config.scheduler == "CosineAnnealingLR" + or config.scheduler == "CyclicLR" + ): + scheduler.step() + + ncuts_losses.append(epoch_ncuts_loss / len(dataloader)) + rec_losses.append(epoch_rec_loss / len(dataloader)) + total_losses.append(epoch_loss / len(dataloader)) + + if WANDB_INSTALLED: + wandb.log({"Ncuts loss_epoch": ncuts_losses[-1]}) + wandb.log({"Reconstruction loss_epoch": rec_losses[-1]}) + wandb.log({"Sum of losses_epoch": total_losses[-1]}) + # wandb.log({"epoch": epoch}) + # wandb.log({"learning_rate model": optimizerW.param_groups[0]["lr"]}) + # wandb.log({"learning_rate encoder": optimizerE.param_groups[0]["lr"]}) + wandb.log({"learning_rate model": optimizer.param_groups[0]["lr"]}) + + print("Ncuts loss: ", ncuts_losses[-1]) + if epoch > 0: + print( + "Ncuts loss difference: ", + ncuts_losses[-1] - ncuts_losses[-2], + ) + print("Reconstruction loss: ", rec_losses[-1]) + if epoch > 0: + print( + "Reconstruction loss difference: ", + rec_losses[-1] - rec_losses[-2], + ) + print("Sum of losses: ", total_losses[-1]) + if epoch > 0: + print( + "Sum of losses difference: ", + total_losses[-1] - total_losses[-2], + ) + + # Update the learning rate + if config.scheduler == "ReduceLROnPlateau": + # schedulerE.step(epoch_ncuts_loss) + # schedulerW.step(epoch_rec_loss) + scheduler.step(epoch_rec_loss) + if ( + config.eval_volume_directory is not None + and (epoch + 1) % config.val_interval == 0 + ): + model.eval() + print("Validating...") + with torch.no_grad(): + for _k, val_data in enumerate(eval_dataloader): + val_inputs, val_labels = ( + val_data["image"].to(device), + val_data["label"].to(device), + ) + + # normalize val_inputs across channels + if config.normalize_input: + for i in range(val_inputs.shape[0]): + for j in range(val_inputs.shape[1]): + val_inputs[i][j] = normalize_function( + val_inputs[i][j] + ) + + val_outputs = model.forward_encoder(val_inputs) + val_outputs = AsDiscrete(threshold=0.5)(val_outputs) + + # compute metric for current iteration + for channel in range(val_outputs.shape[1]): + max_dice_channel = torch.argmax( + torch.Tensor( + [ + dice_coeff( + y_pred=val_outputs[ + :, + channel : (channel + 1), + :, + :, + :, + ], + y_true=val_labels, + ) + ] + ) + ) + + dice_metric( + y_pred=val_outputs[ + :, + max_dice_channel : (max_dice_channel + 1), + :, + :, + :, + ], + y=val_labels, + ) + # if plot_val_input: # only once + # logged_image = val_inputs.detach().cpu().numpy() + # logged_image = np.swapaxes(logged_image, 2, 4) + # logged_image = logged_image[0, :, 32, :, :] + # images = wandb.Image( + # logged_image, caption="Validation input" + # ) + # + # wandb.log({"val/input": images}) + # plot_val_input = False + + # if k == 2 and (30 <= epoch <= 50 or epoch % 100 == 0): + # logged_image = val_outputs.detach().cpu().numpy() + # logged_image = np.swapaxes(logged_image, 2, 4) + # logged_image = logged_image[ + # 0, max_dice_channel, 32, :, : + # ] + # images = wandb.Image( + # logged_image, caption="Validation output" + # ) + # + # wandb.log({"val/output": images}) + # dice_metric(y_pred=val_outputs[:, 2:, :,:,:], y=val_labels) + # dice_metric(y_pred=val_outputs[:, 1:, :, :, :], y=val_labels) + + # import napari + # view = napari.Viewer() + # view.add_image(val_inputs.cpu().numpy(), name="input") + # view.add_image(val_labels.cpu().numpy(), name="label") + # vis_out = np.array( + # [i.detach().cpu().numpy() for i in val_outputs], + # dtype=np.float32, + # ) + # crf_out = np.array( + # [i.detach().cpu().numpy() for i in crf_outputs], + # dtype=np.float32, + # ) + # view.add_image(vis_out, name="output") + # view.add_image(crf_out, name="crf_output") + # napari.run() + + # aggregate the final mean dice result + metric = dice_metric.aggregate().item() + print("Validation Dice score: ", metric) + if best_dice < metric < 2: + best_dice = metric + best_dice_epoch = epoch + 1 + if config.save_model: + save_best_path = Path(config.save_model_path).parents[ + 0 + ] + save_best_path.mkdir(parents=True, exist_ok=True) + save_best_name = Path(config.save_model_path).stem + save_path = ( + str(save_best_path / save_best_name) + + "_best_metric.pth" + ) + print(f"Saving new best model to {save_path}") + torch.save(model.state_dict(), save_path) + + if WANDB_INSTALLED: + # log validation dice score for each validation round + wandb.log({"val/dice_metric": metric}) + + # reset the status for next validation round + dice_metric.reset() + + print( + "ETA: ", + (time.time() - startTime) + * (config.num_epochs / (epoch + 1) - 1) + / 60, + "minutes", + ) + print("-" * 20) + + # Save the model + if config.save_model and epoch % config.save_every == 0: + torch.save(model.state_dict(), config.save_model_path) + # with open(config.save_losses_path, "wb") as f: + # pickle.dump((ncuts_losses, rec_losses), f) + + print("Training finished") + print(f"Best dice metric : {best_dice}") + if WANDB_INSTALLED and config.eval_volume_directory is not None: + wandb.log( + { + "best_dice_metric": best_dice, + "best_metric_epoch": best_dice_epoch, + } + ) + print("*" * 50) + + # Save the model + if config.save_model: + print("Saving the model to: ", config.save_model_path) + torch.save(model.state_dict(), config.save_model_path) + # with open(config.save_losses_path, "wb") as f: + # pickle.dump((ncuts_losses, rec_losses), f) + if WANDB_INSTALLED: + model_artifact = wandb.Artifact( + "WNet", + type="model", + description="WNet benchmark", + metadata=dict(WANDB_CONFIG), + ) + model_artifact.add_file(config.save_model_path) + wandb.log_artifact(model_artifact) + + return ncuts_losses, rec_losses, model + + +def get_dataset(config): + """Creates a Dataset from the original data using the tifffile library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + train_files = [d.get("image") for d in train_files] + volumes = tiff.imread(train_files).astype(np.float32) + volume_shape = volumes.shape + + if config.normalize_input: + volumes = np.array( + [ + # mad_normalization(volume) + config.normalizing_function(volume) + for volume in volumes + ] + ) + # mean = volumes.mean(axis=0) + # std = volumes.std(axis=0) + # volumes = (volumes - mean) / std + # print("NORMALIZED VOLUMES") + # print(volumes.shape) + # [print("MIN MAX", volume.flatten().min(), volume.flatten().max()) for volume in volumes] + # print(volumes.mean(axis=0), volumes.std(axis=0)) + + dataset = CacheDataset(data=volumes) + + return (volume_shape, dataset) + + # train_files = create_dataset_dict_no_labs( + # volume_directory=config.train_volume_directory + # ) + # train_files = [d.get("image") for d in train_files] + # volumes = [] + # for file in train_files: + # image = tiff.imread(file).astype(np.float32) + # image = np.expand_dims(image, axis=0) # add channel dimension + # volumes.append(image) + # # volumes = tiff.imread(train_files).astype(np.float32) + # volume_shape = volumes[0].shape + # # print(volume_shape) + # + # if config.do_augmentation: + # augmentation = Compose( + # [ + # ScaleIntensityRange( + # a_min=0, + # a_max=2000, + # b_min=0.0, + # b_max=1.0, + # clip=True, + # ), + # RandShiftIntensity(offsets=0.1, prob=0.5), + # RandFlip(spatial_axis=[1], prob=0.5), + # RandFlip(spatial_axis=[2], prob=0.5), + # RandRotate90(prob=0.1, max_k=3), + # ] + # ) + # else: + # augmentation = None + # + # dataset = CacheDataset(data=np.array(volumes), transform=augmentation) + # + # return (volume_shape, dataset) + + +def get_patch_dataset(config): + """Creates a Dataset from the original data using the tifffile library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + + patch_func = Compose( + [ + LoadImaged(keys=["image"], image_only=True), + EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"), + RandSpatialCropSamplesd( + keys=["image"], + roi_size=( + config.patch_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=config.num_patches, + ), + Orientationd(keys=["image"], axcodes="PLI"), + SpatialPadd( + keys=["image"], + spatial_size=(get_padding_dim(config.patch_size)), + ), + EnsureTyped(keys=["image"]), + ] + ) + + train_transforms = Compose( + [ + ScaleIntensityRanged( + keys=["image"], + a_min=0, + a_max=2000, + b_min=0.0, + b_max=1.0, + clip=True, + ), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[1], prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[2], prob=0.5), + RandRotate90d(keys=["image"], prob=0.1, max_k=3), + EnsureTyped(keys=["image"]), + ] + ) + + dataset = PatchDataset( + data=train_files, + samples_per_image=config.num_patches, + patch_func=patch_func, + transform=train_transforms, + ) + + return config.patch_size, dataset + + +def get_patch_eval_dataset(config): + eval_files = create_dataset_dict( + volume_directory=config.eval_volume_directory + "/vol", + label_directory=config.eval_volume_directory + "/lab", + ) + + patch_func = Compose( + [ + LoadImaged(keys=["image", "label"], image_only=True), + EnsureChannelFirstd( + keys=["image", "label"], channel_dim="no_channel" + ), + # NormalizeIntensityd(keys=["image"]) if config.normalize_input else lambda x: x, + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + config.patch_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=config.eval_num_patches, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=(get_padding_dim(config.patch_size)), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + + eval_transforms = Compose( + [ + EnsureTyped(keys=["image", "label"]), + ] + ) + + return PatchDataset( + data=eval_files, + samples_per_image=config.eval_num_patches, + patch_func=patch_func, + transform=eval_transforms, + ) + + +def get_dataset_monai(config): + """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + # print(train_files) + # print(len(train_files)) + # print(train_files[0]) + first_volume = LoadImaged(keys=["image"])(train_files[0]) + first_volume_shape = first_volume["image"].shape + + # Transforms to be applied to each volume + load_single_images = Compose( + [ + LoadImaged(keys=["image"]), + EnsureChannelFirstd(keys=["image"]), + Orientationd(keys=["image"], axcodes="PLI"), + SpatialPadd( + keys=["image"], + spatial_size=(get_padding_dim(first_volume_shape)), + ), + EnsureTyped(keys=["image"]), + ] + ) + + if config.do_augmentation: + train_transforms = Compose( + [ + ScaleIntensityRanged( + keys=["image"], + a_min=0, + a_max=2000, + b_min=0.0, + b_max=1.0, + clip=True, + ), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[1], prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[2], prob=0.5), + RandRotate90d(keys=["image"], prob=0.1, max_k=3), + EnsureTyped(keys=["image"]), + ] + ) + else: + train_transforms = EnsureTyped(keys=["image"]) + + # Create the dataset + dataset = CacheDataset( + data=train_files, + transform=Compose(load_single_images, train_transforms), + ) + + return first_volume_shape, dataset + + +def get_scheduler(config, optimizer, verbose=False): + scheduler_name = config.scheduler + if scheduler_name == "None": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, + eta_min=config.lr - 1e-6, + verbose=verbose, + ) + + elif scheduler_name == "ReduceLROnPlateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=schedulers["ReduceLROnPlateau"]["factor"], + patience=schedulers["ReduceLROnPlateau"]["patience"], + verbose=verbose, + ) + elif scheduler_name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=schedulers["CosineAnnealingLR"]["T_max"], + eta_min=schedulers["CosineAnnealingLR"]["eta_min"], + verbose=verbose, + ) + elif scheduler_name == "CosineAnnealingWarmRestarts": + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=schedulers["CosineAnnealingWarmRestarts"]["T_0"], + eta_min=schedulers["CosineAnnealingWarmRestarts"]["eta_min"], + T_mult=schedulers["CosineAnnealingWarmRestarts"]["T_mult"], + verbose=verbose, + ) + elif scheduler_name == "CyclicLR": + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=schedulers["CyclicLR"]["base_lr"], + max_lr=schedulers["CyclicLR"]["max_lr"], + step_size_up=schedulers["CyclicLR"]["step_size_up"], + mode=schedulers["CyclicLR"]["mode"], + cycle_momentum=False, + ) + else: + raise ValueError(f"Scheduler {scheduler_name} not provided") + return scheduler + + +if __name__ == "__main__": + weights_location = str( + # Path(__file__).resolve().parent / "../weights/wnet.pth" + # "../wnet_SUM_MSE_DAPI_rad2_best_metric.pth" + ) + train( + # weights_location + ) diff --git a/notebooks/train_wnet.ipynb b/notebooks/train_wnet.ipynb new file mode 100644 index 00000000..4fb6c0f4 --- /dev/null +++ b/notebooks/train_wnet.ipynb @@ -0,0 +1,267 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:14.017741900Z", + "start_time": "2023-07-10T08:00:14.007742500Z" + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.code_models.models.wnet.train_wnet import Config, train\n", + "from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "config = Config()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:14.382675700Z", + "start_time": "2023-07-10T08:00:14.354604Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic config :" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "config.num_epochs = 100\n", + "config.val_interval = 1 # performs validation with test dataset every n epochs\n", + "config.batch_size = 1" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:15.040773600Z", + "start_time": "2023-07-10T08:00:15.020804400Z" + } + } + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Image directories :\n", + "- `train_volume_directory` : The path to the folder containing the 3D .tif files on which to train\n", + "- `eval_volume_directory` : If available, the path to the validation set to compute Dice metric on; labels should be in a \"lab\" folder, volumes in \"vol\" at the specified path. Images and labels should match when sorted alphabetically" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "config.train_volume_directory = str(Path.home() / \"Desktop/Code/WNet-benchmark/dataset/VIP_small\")\n", + "config.eval_volume_directory = None\n", + "\n", + "config.save_model_path = \"./results\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:15.810624400Z", + "start_time": "2023-07-10T08:00:15.791682400Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Advanced config\n", + "Note : more parameters can be found in the config.py file, depending on your needs" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "config.in_channels = 1\n", + "config.out_channels = 1\n", + "config.num_classes = 2\n", + "config.dropout = 0.65\n", + "\n", + "config.lr = 1e-6 # learning rate\n", + "config.scheduler = \"None\" # \"CosineAnnealingLR\" # \"ReduceLROnPlateau\" # can be further tweaked in config\n", + "config.weight_decay = 0.01 # None" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:16.455904800Z", + "start_time": "2023-07-10T08:00:16.445901900Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Config:\n", + "('in_channels', 1)\n", + "('out_channels', 1)\n", + "('num_classes', 2)\n", + "('dropout', 0.65)\n", + "('use_clipping', False)\n", + "('clipping', 1)\n", + "('lr', 1e-06)\n", + "('scheduler', 'None')\n", + "('weight_decay', 0.01)\n", + "('intensity_sigma', 1)\n", + "('spatial_sigma', 4)\n", + "('radius', 2)\n", + "('n_cuts_weight', 0.5)\n", + "('reconstruction_loss', 'MSE')\n", + "('rec_loss_weight', 0.005)\n", + "('num_epochs', 100)\n", + "('val_interval', 1)\n", + "('batch_size', 1)\n", + "('num_workers', 4)\n", + "('sa', 50)\n", + "('sb', 20)\n", + "('sg', 1)\n", + "('w1', 50)\n", + "('w2', 20)\n", + "('n_iter', 5)\n", + "('train_volume_directory', 'C:\\\\Users\\\\Cyril\\\\Desktop\\\\Code\\\\WNet-benchmark\\\\dataset\\\\VIP_small')\n", + "('eval_volume_directory', None)\n", + "('normalize_input', True)\n", + "('normalizing_function', )\n", + "('use_patch', False)\n", + "('patch_size', (64, 64, 64))\n", + "('num_patches', 30)\n", + "('eval_num_patches', 20)\n", + "('do_augmentation', True)\n", + "('parallel', False)\n", + "('save_model', True)\n", + "('save_model_path', './../results/new_model/wnet_new_model_all_data_3class.pth')\n", + "('save_every', 5)\n", + "('weights_path', None)\n", + "Initializing training...\n", + "Getting the data\n", + "2023-07-10 10:00:17,137 - Images :\n", + "2023-07-10 10:00:17,137 - 1\n", + "2023-07-10 10:00:17,137 - 2\n", + "2023-07-10 10:00:17,137 - **********\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 2/2 [00:00 Date: Mon, 10 Jul 2023 10:33:56 +0200 Subject: [PATCH 573/577] Remove dask --- pyproject.toml | 2 +- requirements.txt | 1 - setup.cfg | 3 --- tox.ini | 2 -- 4 files changed, 1 insertion(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e39a7522..0e82ed6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "napari[all]>=0.4.14", "QtPy", "opencv-python>=4.5.5", - "dask-image>=0.6.0", +# "dask-image>=0.6.0", "scikit-image>=0.19.2", "matplotlib>=3.4.1", "tifffile>=2022.2.9", diff --git a/requirements.txt b/requirements.txt index 3ca0e56d..ada03ae4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,6 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -dask-image>=0.6.0 matplotlib>=3.4.1 ruff tifffile>=2022.2.9 diff --git a/setup.cfg b/setup.cfg index 8ee82f96..7a72482a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,14 +38,11 @@ package_dir = =. # add your package requirements here -# the long list after monai is due to monai optional requirements... Not sure how to know in advance which readers it wil use -# FIXME remove dask install_requires = numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 - dask-image>=0.6.0 scikit-image>=0.19.2 matplotlib>=3.4.1 tifffile>=2022.2.9 diff --git a/tox.ini b/tox.ini index 0605fc8c..1b9b5e22 100644 --- a/tox.ini +++ b/tox.ini @@ -29,8 +29,6 @@ passenv = deps = pytest # https://docs.pytest.org/en/latest/contents.html pytest-cov # https://pytest-cov.readthedocs.io/en/latest/ -; dask-image -; # you can remove these if you don't use them napari PyQt5 magicgui From 02fdcf7ecc499845101adb98aae7c1a8050677fa Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 10 Jul 2023 10:34:32 +0200 Subject: [PATCH 574/577] WNet model docs --- docs/res/guides/inference_module_guide.rst | 35 ++++++++++++++-------- docs/res/guides/training_wnet.rst | 16 ++++++++++ 2 files changed, 38 insertions(+), 13 deletions(-) create mode 100644 docs/res/guides/training_wnet.rst diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst index 373e9d0d..560282ce 100644 --- a/docs/res/guides/inference_module_guide.rst +++ b/docs/res/guides/inference_module_guide.rst @@ -7,8 +7,9 @@ This module allows you to use pre-trained segmentation algorithms (written in Py to automatically label cells. .. important:: - Currently, only inference on **3D volumes is supported**. Your image and label folders should both contain a set of - **3D image files**, currently either **.tif** or **.tiff**. + Currently, only inference on **3D volumes is supported**. If using folders, your images and labels folders + should both contain a set of **3D image files**, either **.tif** or **.tiff**. + Otherwise you may run inference on layers in napari. Currently, the following pre-trained models are available : @@ -20,6 +21,7 @@ SegResNet `3D MRI brain tumor segmentation using autoencoder regularizati TRAILMAP_MS A PyTorch implementation of the `TRAILMAP project on GitHub`_ pretrained with mesoSPIM data TRAILMAP An implementation of the `TRAILMAP project on GitHub`_ using a `3DUNet for PyTorch`_ SwinUNetR `Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images`_ +WNet `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_ ============== ================================================================================================ .. _Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation: https://arxiv.org/pdf/1606.04797.pdf @@ -27,6 +29,10 @@ SwinUNetR `Swin Transformers for Semantic Segmentation of Brain Tumors i .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP .. _3DUnet for Pytorch: https://github.com/wolny/pytorch-3dunet .. _Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images: https://arxiv.org/abs/2201.01266 +.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506 + +.. note:: + For WNet-specific instruction please refer to the appropriate section below. Interface and functionalities -------------------------------- @@ -67,8 +73,7 @@ Interface and functionalities * **Instance segmentation** : - | You can convert the semantic segmentation into instance labels by using either the `watershed`_ or `connected components`_ method. - | You can set the probability threshold from which a pixel is considered as a valid instance, as well as the minimum size in pixels for objects. All smaller objects will be removed. + | You can convert the semantic segmentation into instance labels by using either the Voronoi-Otsu, `Watershed`_ or `Connected Components`_ method, as detailed in :ref:`utils_module_guide`. | Instance labels will be saved (and shown if applicable) separately from other results. @@ -78,7 +83,7 @@ Interface and functionalities * **Computing objects statistics** : - You can choose to compute various stats from the labels and save them to a csv for later use. + You can choose to compute various stats from the labels and save them to a .csv for later use. This includes, for each object : @@ -98,13 +103,6 @@ Interface and functionalities In the ``notebooks`` folder you can find an example of plotting cell statistics using the result csv. -* **Viewing results** : - - | You can also select whether you'd like to **see the results** in napari afterwards. - | By default the first image processed will be displayed, but you can choose to display up to **ten at once**. - | You can also request to see the originals. - - When you are done choosing your parameters, you can press the **Start** button to begin the inference process. Once it has finished, results will be saved then displayed in napari; each output will be paired with its original. On the left side, a progress bar and a log will keep you informed on the process. @@ -115,7 +113,7 @@ On the left side, a progress bar and a log will keep you informed on the process | ``{original_name}_{model}_{date & time}_pred{id}.file_ext`` | For example, using a VNet on the third image of a folder, called "somatomotor.tif" will yield the following name : | *somatomotor_VNet_2022_04_06_15_49_42_pred3.tif* - | Instance labels will have the "Instance_seg" prefix appened to the name. + | Instance labels will have the "Instance_seg" prefix appended to the name. .. hint:: @@ -128,6 +126,17 @@ On the left side, a progress bar and a log will keep you informed on the process .. note:: You can save the log after the worker is finished to easily remember which parameters you ran inference with. +WNet +-------------------------------- + +The WNet model, from the paper `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_, is a fully unsupervised model that can be used to segment images without any labels. +It clusters pixels based on brightness, and can be used to segment cells in a variety of modalities. +Its use and available options are similar to the above models, with a few differences : +.. note:: + | Our provided, pre-trained model should use an input size of 64x64x64. As such, window inference is always enabled + | and set to 64. If you want to use a different size, you will have to train your own model using the provided notebook. +All it requires are images; for nucleus segmentation, it is recommended to use 2 classes (default). + Source code -------------------------------- * :doc:`../code/plugin_model_inference` diff --git a/docs/res/guides/training_wnet.rst b/docs/res/guides/training_wnet.rst new file mode 100644 index 00000000..6f0690f3 --- /dev/null +++ b/docs/res/guides/training_wnet.rst @@ -0,0 +1,16 @@ +.. _training_wnet: + +WNet model training +=================== + +This plugin provides a reimplemented version of the WNet model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_. +In order to train your own model, you may use the provided Jupyter notebook. + +The WNet uses brightness to cluster objects vs background; to get the most out of the model please use image regions with minimal +artifacts. You may then use one of the supervised models to train in order to achieve more resilient segmentation if you have many artifacts. + +The WNet should not require a very large amount of data to train, but during inference images should be similar to those +the model was trained on. + + +.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506 From 1839f396c2e76ac26c3f4d4f634da199db7ca858 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 10 Jul 2023 11:06:55 +0200 Subject: [PATCH 575/577] Added QoL shape info for layer selecter --- napari_cellseg3d/interface.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index a73ea62d..3c43c81d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -762,14 +762,20 @@ def __init__( self.layer_list = DropdownMenu( parent=self, text_label=name, fixed=False ) + self.layer_description = make_label("Shape:", parent=self) + self.layer_description.setVisible(False) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? self._viewer.layers.events.inserted.connect(partial(self._add_layer)) self._viewer.layers.events.removed.connect(partial(self._remove_layer)) self.layer_list.currentIndexChanged.connect(self._update_tooltip) + self.layer_list.currentTextChanged.connect(self._update_description) - add_widgets(self.layout, [self.layer_list.label, self.layer_list]) + add_widgets( + self.layout, + [self.layer_list.label, self.layer_list, self.layer_description], + ) self._check_for_layers() def _check_for_layers(self): @@ -780,6 +786,14 @@ def _check_for_layers(self): def _update_tooltip(self): self.layer_list.setToolTip(self.layer_list.currentText()) + def _update_description(self): + if self.layer_list.currentText() != "": + self.layer_description.setVisible(True) + shape_desc = f"Shape : {self.layer_data().shape}" + self.layer_description.setText(shape_desc) + else: + self.layer_description.setVisible(False) + def _add_layer(self, event): inserted_layer = event.value @@ -803,7 +817,10 @@ def set_layer_type(self, layer_type): # no @property due to Qt constraint self._check_for_layers() def layer(self): - return self._viewer.layers[self.layer_name()] + try: + return self._viewer.layers[self.layer_name()] + except ValueError: + return None def layer_name(self): return self.layer_list.currentText() From d639b21be786c779a036bd7f34cd2177e4cfd644 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Jul 2023 08:41:10 +0200 Subject: [PATCH 576/577] WNet fixes + PR feedback improvements --- docs/res/guides/custom_model_template.rst | 2 +- docs/res/guides/training_wnet.rst | 28 ++++++++++++++++--- napari_cellseg3d/_tests/test_models.py | 2 +- .../_tests/test_plugin_inference.py | 2 ++ .../code_models/models/model_WNet.py | 6 ++-- .../code_models/models/wnet/model.py | 10 +++++-- 6 files changed, 39 insertions(+), 11 deletions(-) diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst index a70df29b..b7eb65e3 100644 --- a/docs/res/guides/custom_model_template.rst +++ b/docs/res/guides/custom_model_template.rst @@ -4,7 +4,7 @@ Advanced : Declaring a custom model ============================================= .. warning:: - **WIP** : Adding new models is still a work in progress and will likely not work simply by adding the model in the plugin. + **WIP** : Adding new models is still a work in progress and will likely not work out of the box, leading to errors. Please `file an issue`_ if you would like to add a custom model and we will help you get it working. diff --git a/docs/res/guides/training_wnet.rst b/docs/res/guides/training_wnet.rst index 6f0690f3..ecd20542 100644 --- a/docs/res/guides/training_wnet.rst +++ b/docs/res/guides/training_wnet.rst @@ -3,14 +3,34 @@ WNet model training =================== -This plugin provides a reimplemented version of the WNet model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_. -In order to train your own model, you may use the provided Jupyter notebook. +This plugin provides a reimplemented, custom version of the WNet model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_. +In order to train your own model, you may use the provided Jupyter notebook; support for in-plugin training might be added in the future. The WNet uses brightness to cluster objects vs background; to get the most out of the model please use image regions with minimal -artifacts. You may then use one of the supervised models to train in order to achieve more resilient segmentation if you have many artifacts. +artifacts. You may then train one of the supervised models in order to achieve more resilient segmentation if you have many artifacts. The WNet should not require a very large amount of data to train, but during inference images should be similar to those -the model was trained on. +the model was trained on; you can retrain from our pretrained model to your set of images to quickly reach good performance. + +The model has two losses, the SoftNCut loss which clusters pixels according to brightness, and a reconstruction loss, either +Mean Square Error (MSE) or Binary Cross Entropy (BCE). +Unlike the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once. +The SoftNcuts is bounded between 0 and 1; the MSE may take large values. + +For good performance, one should wait for the SoftNCut to reach a plateau, the reconstruction loss must also diminish but it's generally less critical. + + +Common issues troubleshooting +------------------------------ +If you do not find a satisfactory answer here, please `open an issue`_ ! + +- **The NCuts loss explodes after a few epochs** : Lower the learning rate + +- **The NCuts loss does not converge and is unstable** : + The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image; for reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1. + +- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss ot 0.5, OR adjust the weight of the MSE loss to make it closer to 1. .. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506 +.. _open an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py index 3845bb6a..ebb3a614 100644 --- a/napari_cellseg3d/_tests/test_models.py +++ b/napari_cellseg3d/_tests/test_models.py @@ -49,7 +49,7 @@ def test_soft_ncuts_loss(): res = loss.forward(labels, labels) assert isinstance(res, torch.Tensor) - assert 0 <= res <= 1 + assert 0 <= res <= 1 # ASSUMES NUMBER OF CLASS IS 2, NOT CORRECT IF K>2 def test_crf_batch(): diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 779f5094..0258f243 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -47,6 +47,7 @@ def test_inference(make_napari_viewer, qtbot): widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None assert widget.model_info is not None + widget.window_infer_box.setChecked(False) worker = widget._create_worker_from_config(widget.worker_config) assert worker.config is not None @@ -63,5 +64,6 @@ def test_inference(make_napari_viewer, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) assert res.result.shape == (8, 8, 8) + assert res.instance_labels.shape == (8, 8, 8) widget.on_yield(res) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index 62142e73..a2fce724 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -9,8 +9,8 @@ class WNet_(WNet_encoder): def __init__( self, in_channels=1, - out_channels=1, - num_classes=2, + out_channels=2, + # num_classes=2, device="cpu", **kwargs, ): @@ -18,7 +18,7 @@ def __init__( device=device, in_channels=in_channels, out_channels=out_channels, - num_classes=num_classes, + # num_classes=num_classes, ) # def train(self: T, mode: bool = True) -> T: diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index e1dfbec8..5ef726b6 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -21,7 +21,13 @@ class WNet_encoder(nn.Module): """WNet with encoder only.""" - def __init__(self, device, in_channels=1, out_channels=1, num_classes=2): + def __init__( + self, + device, + in_channels=1, + out_channels=2 + # num_classes=2 + ): super().__init__() self.device = device self.encoder = UNet( @@ -44,7 +50,7 @@ class WNet(nn.Module): def __init__( self, in_channels=1, - out_channels=1, + out_channels=2, num_classes=2, dropout=0.65, ): From 308370514cbb5c0e3001e3ab6de04d5bae8bd548 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 12 Jul 2023 09:19:50 +0200 Subject: [PATCH 577/577] Added imagecodecs to open external datasets --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0e82ed6d..c188c9c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "matplotlib>=3.4.1", "tifffile>=2022.2.9", "imageio-ffmpeg>=0.4.5", + "imagecodecs>=2023.3.16", "torch>=1.11", "monai[nibabel,einops]>=0.9.0", "itk",