deepflash2
Advanced tools
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_config.ipynb (unless otherwise specified). | ||
| __all__ = ['Config'] | ||
| # Cell | ||
| from dataclasses import dataclass, asdict | ||
| from pathlib import Path | ||
| import json | ||
| import torch | ||
| # Cell | ||
| @dataclass | ||
| class Config: | ||
| "Config class for settings." | ||
| # Project | ||
| project_dir:str = '.' | ||
| # GT Estimation Settings | ||
| # staple_thres:float = 0.5 | ||
| # staple_fval:int= 1 | ||
| vote_undec:int = 0 | ||
| # Train General Settings | ||
| n_models:int = 5 | ||
| max_splits:int=5 | ||
| random_state:int = 42 | ||
| use_gpu:bool = True | ||
| # Pytorch Segmentation Model Settings | ||
| arch:str = 'Unet' | ||
| encoder_name:str = 'tu-convnext_tiny' | ||
| encoder_weights:str = 'imagenet' | ||
| # Train Data Settings | ||
| num_classes:int = 2 | ||
| tile_shape:int = 512 | ||
| scale:float = 1. | ||
| instance_labels:bool = False | ||
| # Train Settings | ||
| base_lr:float = 0.001 | ||
| batch_size:int = 4 | ||
| weight_decay:float = 0.001 | ||
| mixed_precision_training:bool = True | ||
| optim:str = 'Adam' | ||
| loss:str = 'CrossEntropyDiceLoss' | ||
| n_epochs:int = 25 | ||
| sample_mult:int = 0 | ||
| # Train Data Augmentation | ||
| gamma_limit_lower:int = 80 | ||
| gamma_limit_upper:int = 120 | ||
| CLAHE_clip_limit:float = 0.0 | ||
| brightness_limit:float = 0.0 | ||
| contrast_limit:float = 0.0 | ||
| flip:bool = True | ||
| rot:int = 360 | ||
| distort_limit:float = 0 | ||
| # Loss Settings | ||
| mode:str = 'multiclass' #currently only tested for multiclass | ||
| loss_alpha:float = 0.5 # Twerksky/Focal loss | ||
| loss_beta:float = 0.5 # Twerksy Loss | ||
| loss_gamma:float = 2.0 # Focal loss | ||
| loss_smooth_factor:float = 0. #SoftCrossEntropyLoss | ||
| # Pred/Val Settings | ||
| use_tta:bool = True | ||
| max_tile_shift: float = 0.5 | ||
| border_padding_factor:float = 0.25 | ||
| use_gaussian: bool = True | ||
| gaussian_kernel_sigma_scale: float = 0.125 | ||
| min_pixel_export:int = 0 | ||
| # Instance Segmentation Settings | ||
| cellpose_model:str='nuclei' | ||
| cellpose_diameter:int=0 | ||
| cellpose_export_class:int=1 | ||
| instance_segmentation_metrics:bool=False | ||
| # Folder Structure | ||
| gt_dir:str = 'GT_Estimation' | ||
| train_dir:str = 'Training' | ||
| pred_dir:str = 'Prediction' | ||
| ens_dir:str = 'models' | ||
| val_dir:str = 'valid' | ||
| def __post_init__(self): | ||
| self.set_device() | ||
| def set_device(self, device:str=None): | ||
| if device is None: | ||
| self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | ||
| else: | ||
| self.device = device | ||
| @property | ||
| def albumentation_kwargs(self): | ||
| kwargs = ['gamma_limit_lower', 'gamma_limit_upper', 'CLAHE_clip_limit', | ||
| 'brightness_limit', 'contrast_limit', 'distort_limit'] | ||
| return dict(filter(lambda x: x[0] in kwargs, self.__dict__.items())) | ||
| @property | ||
| def inference_kwargs(self): | ||
| inference_kwargs = ['use_tta', 'max_tile_shift', 'use_gaussian', 'scale', | ||
| 'gaussian_kernel_sigma_scale', 'border_padding_factor'] | ||
| return dict(filter(lambda x: x[0] in inference_kwargs, self.__dict__.items())) | ||
| def save(self, path): | ||
| 'Save configuration to path' | ||
| path = Path(path).with_suffix('.json') | ||
| with open(path, 'w') as config_file: | ||
| json.dump(asdict(self), config_file) | ||
| print(f'Saved current configuration to {path}.json') | ||
| return path | ||
| def load(self, path): | ||
| 'Load configuration from path' | ||
| path = Path(path) | ||
| try: | ||
| with open(path) as config_file: c = json.load(config_file) | ||
| if not Path(c['project_dir']).is_dir(): c['project_dir']='deepflash2' | ||
| for k,v in c.items(): setattr(self, k, v) | ||
| print(f'Successsfully loaded configuration from {path}') | ||
| except: | ||
| print('Error! Select valid config file (.json)') |
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04_inference.ipynb (unless otherwise specified). | ||
| __all__ = ['torch_gaussian', 'gaussian_kernel_2d', 'epistemic_uncertainty', 'aleatoric_uncertainty', 'uncertainty', | ||
| 'get_in_slices_1d', 'get_out_slices_1d', 'TileModule', 'InferenceEnsemble'] | ||
| # Cell | ||
| from typing import Tuple, List | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from torchvision.transforms import Normalize | ||
| import deepflash2.tta as tta | ||
| # Cell | ||
| # adapted from https://github.com/scipy/scipy/blob/f2ec91c4908f9d67b5445fbfacce7f47518b35d1/scipy/signal/windows.py#L976 | ||
| @torch.jit.script | ||
| def torch_gaussian(M:int, std:float, sym:bool=True) ->torch.Tensor: | ||
| 'Returns a Gaussian window' | ||
| assert M > 2, 'Kernel size must be greater than 2.' | ||
| odd = M % 2 | ||
| if not sym and not odd: | ||
| M = M + 1 | ||
| n = torch.arange(0, M) - (M - 1.0) / 2.0 | ||
| sig2 = 2 * std * std | ||
| w = torch.exp(-n ** 2 / sig2) | ||
| if not sym and not odd: | ||
| w = w[:-1] | ||
| return w | ||
| # Cell | ||
| @torch.jit.script | ||
| def gaussian_kernel_2d(patch_size: Tuple[int, int], sigma_scale:float=1/8) ->torch.Tensor: | ||
| 'Returns a 2D Gaussian kernel tensor.' | ||
| patch_size = [patch_size[0], patch_size[1]] | ||
| sigmas = [i * sigma_scale for i in patch_size] | ||
| gkern1ds = [torch_gaussian(kernlen, std=std) for kernlen, std in zip(patch_size, sigmas)] | ||
| gkern2d = torch.outer(gkern1ds[0], gkern1ds[1]) | ||
| gkern2d = gkern2d / gkern2d.max() | ||
| gkern2d[gkern2d==0] = gkern2d.min() | ||
| return gkern2d | ||
| # Cell | ||
| # adapted from https://github.com/ykwon0407/UQ_BNN/blob/master/retina/utils.py | ||
| @torch.jit.script | ||
| def epistemic_uncertainty(x: torch.Tensor): | ||
| return torch.mean(x**2, dim=0) - torch.mean(x, dim=0)**2 | ||
| @torch.jit.script | ||
| def aleatoric_uncertainty(x: torch.Tensor): | ||
| return torch.mean(x * (1 - x), dim=0) | ||
| @torch.jit.script | ||
| def uncertainty(x: torch.Tensor): | ||
| # Add uncertainties | ||
| uncertainty = epistemic_uncertainty(x) + aleatoric_uncertainty(x) | ||
| # Scale to 1 max overall | ||
| uncertainty /= 0.25 | ||
| return uncertainty | ||
| # Cell | ||
| @torch.jit.export | ||
| def get_in_slices_1d(center:torch.Tensor, len_x:int, len_tile:int) ->torch.Tensor: | ||
| start = (len_tile/2-center).clip(0).to(torch.int64) | ||
| stop = torch.tensor(len_tile).clip(max=(len_x-center+len_tile/2).to(torch.int64)) | ||
| return torch.stack((start, stop)) | ||
| @torch.jit.export | ||
| def get_out_slices_1d(center:torch.Tensor, len_x:int, len_tile:int) ->torch.Tensor: | ||
| start = (center - (len_tile/2)).clip(0, len_x).to(torch.int64) | ||
| stop = (center + (len_tile/2)).clip(max=len_x).to(torch.int64) | ||
| return torch.stack((start, stop)) | ||
| # Cell | ||
| class TileModule(torch.nn.Module): | ||
| "Class for tiling data." | ||
| def __init__(self, | ||
| tile_shape = (512, 512), | ||
| scale:float = 1., | ||
| border_padding_factor:float = 0.25, | ||
| max_tile_shift:float = 0.5): | ||
| super(TileModule, self).__init__() | ||
| self.tile_shape = tile_shape | ||
| self.scale = scale | ||
| self.border_padding_factor = border_padding_factor | ||
| self.max_tile_shift = max_tile_shift | ||
| grid_range = [torch.linspace(-self.scale, self.scale, steps=d) for d in tile_shape] | ||
| self.deformationField = torch.meshgrid(*grid_range, indexing='ij') | ||
| @torch.jit.export | ||
| def get_centers_1d(self, len_x:int, len_tile:int) ->torch.Tensor: | ||
| len_padding = float(len_tile*self.border_padding_factor) | ||
| start_point = len_tile/2 - len_padding | ||
| end_point = len_x - start_point | ||
| #end_point = max(len_x - start_point, start_point) | ||
| n_points = int((len_x+2*len_padding)//(len_tile*self.max_tile_shift))+1 | ||
| return torch.linspace(start_point, end_point, n_points, dtype=torch.int64) | ||
| @torch.jit.export | ||
| def get_center_combinations(self, shape:List[int]) ->torch.Tensor: | ||
| c_list = [self.get_centers_1d(shape[i], self.tile_shape[i]) for i in range(2)] | ||
| center_combinations = torch.meshgrid(c_list[0],c_list[1], indexing='ij') | ||
| return torch.stack(center_combinations).permute([2, 1, 0]).reshape(-1,2) | ||
| @torch.jit.export | ||
| def get_slices_and_centers(self, shape:List[int]) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: | ||
| shape = [int(shape[i]/self.scale) for i in range(2)] | ||
| center_combinations = self.get_center_combinations(shape) | ||
| in_slices = [get_in_slices_1d(center_combinations[:,i], shape[i], self.tile_shape[i]) for i in range(2)] | ||
| out_slices = [get_out_slices_1d(center_combinations[:,i], shape[i], self.tile_shape[i]) for i in range(2)] | ||
| scaled_centers = (center_combinations*self.scale).type(torch.int64) | ||
| return in_slices, out_slices, scaled_centers | ||
| @torch.jit.export | ||
| def forward(self, x, center:torch.Tensor) ->torch.Tensor: | ||
| "Apply deformation field to image using interpolation" | ||
| # Align grid to relative position and scale | ||
| grids = [] | ||
| for i in range(2): | ||
| s = x.shape[i] | ||
| scale_ratio = self.tile_shape[i]/s | ||
| relative_center = (center[i]-s/2)/(s/2) | ||
| coords = (self.deformationField[i]*scale_ratio)+relative_center | ||
| grids.append(coords.to(x)) | ||
| # grid with shape (N, H, W, 2) | ||
| vgrid = torch.stack(grids[::-1], dim=-1).to(x).unsqueeze_(0) | ||
| # input with shape (N, C, H, W) | ||
| x = x.permute(2,0,1).unsqueeze_(0) | ||
| # Remap | ||
| x = torch.nn.functional.grid_sample(x, vgrid, mode='nearest', padding_mode='reflection', align_corners=False) | ||
| return x | ||
| # Cell | ||
| class InferenceEnsemble(torch.nn.Module): | ||
| 'Class for model ensemble inference' | ||
| def __init__(self, | ||
| models:List[torch.nn.Module], | ||
| num_classes:int, | ||
| in_channels:int, | ||
| channel_means:List[float], | ||
| channel_stds:List[float], | ||
| tile_shape:Tuple[int,int]=(512,512), | ||
| use_gaussian: bool = True, | ||
| gaussian_kernel_sigma_scale:float = 1./8, | ||
| use_tta:bool=True, | ||
| border_padding_factor:float = 0.25, | ||
| max_tile_shift:float = 0.9, | ||
| scale:float = 1., | ||
| device:str='cpu'): | ||
| super().__init__() | ||
| self.num_classes = num_classes | ||
| self.use_tta = use_tta | ||
| self.use_gaussian = use_gaussian | ||
| self.gaussian_kernel_sigma_scale = gaussian_kernel_sigma_scale | ||
| self.tile_shape = tile_shape | ||
| self.norm = Normalize(channel_means, channel_stds) | ||
| dummy_input = torch.rand(1, in_channels, *self.tile_shape).to(device) | ||
| self.models = torch.nn.ModuleList([torch.jit.trace(m.to(device).eval(), dummy_input) for m in models]) | ||
| self.tiler = torch.jit.script(TileModule(tile_shape=tile_shape, | ||
| scale=scale, | ||
| border_padding_factor=border_padding_factor, | ||
| max_tile_shift=max_tile_shift)) | ||
| mw = gaussian_kernel_2d(tile_shape, gaussian_kernel_sigma_scale) if use_gaussian else torch.ones(tile_shape[0], tile_shape[1]) | ||
| self.register_buffer('mw', mw) | ||
| tfms = [tta.HorizontalFlip(),tta.VerticalFlip()] if self.use_tta else [] | ||
| self.tta_tfms = tta.Compose(tfms) | ||
| def forward(self, x): | ||
| # Extract image shape (assuming HWC) | ||
| sh = x.shape[:-1] | ||
| # Workaround for sh_scaled = [int(s/self.tiler.scale) for s in img_shape] | ||
| sh_scaled = (torch.tensor(sh)/self.tiler.scale).to(torch.int64) | ||
| sh_scaled = [int(t.item()) for t in sh_scaled] | ||
| # Create zero arrays (only on CPU RAM to avoid GPU memory overflow on large images) | ||
| # softmax = torch.zeros((sh_scaled[0], sh_scaled[1], self.num_classes), dtype=torch.float32, device=x.device) | ||
| softmax = torch.zeros((self.num_classes, sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device) | ||
| merge_map = torch.zeros((sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device) | ||
| stdeviation = torch.zeros((sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device) | ||
| # Get slices for tiling | ||
| in_slices, out_slices, center_points = self.tiler.get_slices_and_centers(sh) | ||
| # | ||
| self.mw.to(x) | ||
| # Loop over tiles | ||
| for i, cp in enumerate(center_points): | ||
| tile = self.tiler(x, cp) | ||
| # Normalize | ||
| tile = self.norm(tile) | ||
| smxs = [] | ||
| # Loop over tt-augmentations | ||
| for t in self.tta_tfms.items: | ||
| aug_tile = t.augment(tile) | ||
| # Loop over models | ||
| for model in self.models: | ||
| logits = model(aug_tile) | ||
| logits = t.deaugment(logits) | ||
| smxs.append(F.softmax(logits, dim=1)) | ||
| smxs = torch.stack(smxs) | ||
| ix0, ix1, iy0, iy1 = in_slices[0][0][i], in_slices[0][1][i], in_slices[1][0][i], in_slices[1][1][i] | ||
| ox0, ox1, oy0, oy1 = out_slices[0][0][i], out_slices[0][1][i], out_slices[1][0][i], out_slices[1][1][i] | ||
| # Apply weigthing | ||
| batch_smx = torch.mean(smxs, dim=0)*self.mw.view(1,1,self.mw.shape[0],self.mw.shape[1]) | ||
| #softmax[ox0:ox1, oy0:oy1] += batch_smx.permute(0,2,3,1)[0][ix0:ix1, iy0:iy1].to(softmax) | ||
| softmax[..., ox0:ox1, oy0:oy1] += batch_smx[0, ..., ix0:ix1, iy0:iy1].to(softmax) | ||
| merge_map[ox0:ox1, oy0:oy1] += self.mw[ix0:ix1, iy0:iy1].to(merge_map) | ||
| # Encertainty_estimates | ||
| batch_std = torch.mean(uncertainty(smxs), dim=1)*self.mw.view(1,self.mw.shape[0],self.mw.shape[1]) | ||
| stdeviation[ox0:ox1, oy0:oy1] += batch_std[0][ix0:ix1, iy0:iy1].to(stdeviation) | ||
| # Normalize weighting | ||
| softmax /= torch.unsqueeze(merge_map, 0) | ||
| stdeviation /= merge_map | ||
| # Rescale results | ||
| if self.tiler.scale!=1.: | ||
| # Needs checking if these are the best options | ||
| softmax = F.interpolate(softmax.unsqueeze_(0), size=sh, mode="bilinear", align_corners=False)[0] | ||
| stdeviation = stdeviation.view(1, 1, stdeviation.shape[0], stdeviation.shape[1]) | ||
| stdeviation = F.interpolate(stdeviation, size=sh, mode="bilinear", align_corners=False)[0][0] | ||
| argmax = torch.argmax(softmax, dim=0).to(torch.uint8) | ||
| return argmax, softmax, stdeviation |
| Metadata-Version: 2.1 | ||
| Name: deepflash2 | ||
| Version: 0.1.8 | ||
| Version: 0.2.0 | ||
| Summary: A Deep learning pipeline for segmentation of fluorescent labels in microscopy images | ||
@@ -10,3 +10,2 @@ Home-page: https://github.com/matjesg/deepflash2 | ||
| Keywords: unet,deep learning,semantic segmentation,microscopy,fluorescent labels | ||
| Platform: UNKNOWN | ||
| Classifier: Development Status :: 3 - Alpha | ||
@@ -16,10 +15,8 @@ Classifier: Intended Audience :: Developers | ||
| Classifier: Natural Language :: English | ||
| Classifier: Programming Language :: Python :: 3.6 | ||
| Classifier: Programming Language :: Python :: 3.7 | ||
| Classifier: Programming Language :: Python :: 3.8 | ||
| Requires-Python: >=3.6 | ||
| Requires-Python: >=3.7 | ||
| Description-Content-Type: text/markdown | ||
| License-File: LICENSE | ||
| # Welcome to | ||
@@ -180,3 +177,1 @@ | ||
| The ImagJ-Macro is available [here](https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/ImageJ/Macro_create_maps.ijm). | ||
@@ -9,6 +9,8 @@ pip | ||
| openpyxl | ||
| imagecodecs | ||
| albumentations>=1.0.0 | ||
| natsort>=7.1.1 | ||
| numba>=0.52.0 | ||
| segmentation-models-pytorch>=0.2 | ||
| opencv-python-headless<4.5.5,>=4.1.1 | ||
| timm>=0.5.4 | ||
| segmentation-models-pytorch-deepflash2 |
@@ -10,5 +10,7 @@ CONTRIBUTING.md | ||
| deepflash2/all.py | ||
| deepflash2/config.py | ||
| deepflash2/data.py | ||
| deepflash2/gt.py | ||
| deepflash2/gui.py | ||
| deepflash2/inference.py | ||
| deepflash2/learner.py | ||
@@ -22,5 +24,4 @@ deepflash2/losses.py | ||
| deepflash2.egg-info/dependency_links.txt | ||
| deepflash2.egg-info/entry_points.txt | ||
| deepflash2.egg-info/not-zip-safe | ||
| deepflash2.egg-info/requires.txt | ||
| deepflash2.egg-info/top_level.txt |
@@ -1,1 +0,1 @@ | ||
| __version__ = "0.1.8" | ||
| __version__ = "0.2.0" |
+27
-13
@@ -5,9 +5,8 @@ # AUTOGENERATED BY NBDEV! DO NOT EDIT! | ||
| index = {"Config": "00_learner.ipynb", | ||
| "energy_score": "00_learner.ipynb", | ||
| "EnsemblePredict": "00_learner.ipynb", | ||
| "EnsembleLearner": "00_learner.ipynb", | ||
| index = {"Config": "00_config.ipynb", | ||
| "ARCHITECTURES": "01_models.ipynb", | ||
| "ENCODERS": "01_models.ipynb", | ||
| "get_pretrained_options": "01_models.ipynb", | ||
| "smp.decoders.unet.decoder.UnetDecoder.forward": "01_models.ipynb", | ||
| "PATCH_UNET_DECODER": "01_models.ipynb", | ||
| "create_smp_model": "01_models.ipynb", | ||
@@ -22,5 +21,18 @@ "save_smp_model": "01_models.ipynb", | ||
| "DeformationField": "02_data.ipynb", | ||
| "tiles_in_rectangles": "02_data.ipynb", | ||
| "BaseDataset": "02_data.ipynb", | ||
| "RandomTileDataset": "02_data.ipynb", | ||
| "TileDataset": "02_data.ipynb", | ||
| "EnsembleBase": "03_learner.ipynb", | ||
| "EnsembleLearner": "03_learner.ipynb", | ||
| "EnsemblePredictor": "03_learner.ipynb", | ||
| "torch_gaussian": "04_inference.ipynb", | ||
| "gaussian_kernel_2d": "04_inference.ipynb", | ||
| "epistemic_uncertainty": "04_inference.ipynb", | ||
| "aleatoric_uncertainty": "04_inference.ipynb", | ||
| "uncertainty": "04_inference.ipynb", | ||
| "get_in_slices_1d": "04_inference.ipynb", | ||
| "get_out_slices_1d": "04_inference.ipynb", | ||
| "TileModule": "04_inference.ipynb", | ||
| "InferenceEnsemble": "04_inference.ipynb", | ||
| "LOSSES": "05_losses.ipynb", | ||
@@ -30,2 +42,3 @@ "FastaiLoss": "05_losses.ipynb", | ||
| "JointLoss": "05_losses.ipynb", | ||
| "Poly1CrossEntropyLoss": "05_losses.ipynb", | ||
| "get_loss": "05_losses.ipynb", | ||
@@ -37,6 +50,7 @@ "unzip": "06_utils.ipynb", | ||
| "compose_albumentations": "06_utils.ipynb", | ||
| "ensemble_results": "06_utils.ipynb", | ||
| "clean_show": "06_utils.ipynb", | ||
| "plot_results": "06_utils.ipynb", | ||
| "Recorder.plot_metrics": "06_utils.ipynb", | ||
| "iou": "06_utils.ipynb", | ||
| "multiclass_dice_score": "06_utils.ipynb", | ||
| "binary_dice_score": "06_utils.ipynb", | ||
| "dice_score": "06_utils.ipynb", | ||
@@ -54,9 +68,8 @@ "label_mask": "06_utils.ipynb", | ||
| "BaseTransform": "07_tta.ipynb", | ||
| "HorizontalFlip": "07_tta.ipynb", | ||
| "VerticalFlip": "07_tta.ipynb", | ||
| "Rotate90": "07_tta.ipynb", | ||
| "Chain": "07_tta.ipynb", | ||
| "Transformer": "07_tta.ipynb", | ||
| "Compose": "07_tta.ipynb", | ||
| "Merger": "07_tta.ipynb", | ||
| "HorizontalFlip": "07_tta.ipynb", | ||
| "VerticalFlip": "07_tta.ipynb", | ||
| "Rotate90": "07_tta.ipynb", | ||
| "GRID_COLS": "08_gui.ipynb", | ||
@@ -100,10 +113,11 @@ "COLS_PRED_KEEP": "08_gui.ipynb", | ||
| "import_sitk": "09_gt.ipynb", | ||
| "staple": "09_gt.ipynb", | ||
| "staple_multi_label": "09_gt.ipynb", | ||
| "m_voting": "09_gt.ipynb", | ||
| "msk_show": "09_gt.ipynb", | ||
| "GTEstimator": "09_gt.ipynb"} | ||
| modules = ["learner.py", | ||
| modules = ["config.py", | ||
| "models.py", | ||
| "data.py", | ||
| "learner.py", | ||
| "inference.py", | ||
| "losses.py", | ||
@@ -110,0 +124,0 @@ "utils.py", |
@@ -0,1 +1,2 @@ | ||
| from .config import * | ||
| from .learner import * | ||
@@ -2,0 +3,0 @@ from .data import * |
+122
-95
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_data.ipynb (unless otherwise specified). | ||
| __all__ = ['show', 'preprocess_mask', 'DeformationField', 'BaseDataset', 'RandomTileDataset', 'TileDataset'] | ||
| __all__ = ['show', 'preprocess_mask', 'DeformationField', 'tiles_in_rectangles', 'BaseDataset', 'RandomTileDataset', | ||
| 'TileDataset'] | ||
@@ -23,6 +24,9 @@ # Cell | ||
| from fastcore.all import * | ||
| from fastprogress import progress_bar | ||
| from .utils import clean_show | ||
| # Cell | ||
| def show(*obj, file_name=None, overlay=False, pred=False, | ||
| show_bbox=True, figsize=(10,10), cmap='binary_r', **kwargs): | ||
| def show(*obj, file_name=None, overlay=False, pred=False, num_classes=2, | ||
| show_bbox=False, figsize=(10,10), cmap='viridis', **kwargs): | ||
| "Show image, mask, and weight (optional)" | ||
@@ -63,5 +67,2 @@ if len(obj)==3: | ||
| if cmap is None: | ||
| cmap = 'binary_r' if msk.max()==1 else cmap | ||
| # Weights preprocessing | ||
@@ -78,8 +79,4 @@ if weight is not None: | ||
| # Plot img | ||
| img_ax.imshow(img, cmap=cmap) | ||
| if file_name is not None: | ||
| img_ax.set_title('Image {}'.format(file_name)) | ||
| else: | ||
| img_ax.set_title('Image') | ||
| img_ax.set_axis_off() | ||
| img_title = f'Image {file_name}' if file_name is not None else 'Image' | ||
| clean_show(img_ax, img, img_title, cmap) | ||
@@ -91,7 +88,7 @@ # Plot img and mask | ||
| img_l2o = label2rgb(label_image, image=img, bg_label=0, alpha=.8, image_alpha=1) | ||
| ax[1].set_title('Image + Mask (#ROIs: {})'.format(label_image.max())) | ||
| ax[1].imshow(img_l2o) | ||
| pred_title = 'Image + Mask (#ROIs: {})'.format(label_image.max()) | ||
| clean_show(ax[1], img_l2o, pred_title, None) | ||
| else: | ||
| ax[1].imshow(msk, cmap=cmap) | ||
| ax[1].set_title('Mask') | ||
| vkwargs = {'vmin':0, 'vmax':num_classes-1} | ||
| clean_show(ax[1], msk, 'Mask', cmap, cbar='classes', ticks=num_classes, **vkwargs) | ||
| if show_bbox: ax[1].add_patch(copy(bbox)) | ||
@@ -135,3 +132,3 @@ | ||
| # adapted from Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70. | ||
| def preprocess_mask(clabels=None, instlabels=None, remove_connectivity=True, n_classes = 2): | ||
| def preprocess_mask(clabels=None, instlabels=None, remove_connectivity=True, num_classes = 2): | ||
| "Calculates the weights from the given mask (classlabels `clabels` or `instlabels`)." | ||
@@ -170,3 +167,3 @@ | ||
| # of that class, avoid overlapping instances | ||
| dil = cv2.morphologyEx(il, cv2.MORPH_CLOSE, kernel=np.ones((3,) * n_classes)) | ||
| dil = cv2.morphologyEx(il, cv2.MORPH_CLOSE, kernel=np.ones((3,) * num_classes)) | ||
| overlap_cand = np.unique(np.where(dil!=il, dil, 0)) | ||
@@ -176,3 +173,3 @@ labels[np.isin(il, overlap_cand, invert=True)] = c | ||
| for instance in overlap_cand[1:]: | ||
| objectMaskDil = cv2.dilate((labels == c).astype('uint8'), kernel=np.ones((3,) * n_classes),iterations = 1) | ||
| objectMaskDil = cv2.dilate((labels == c).astype('uint8'), kernel=np.ones((3,) * num_classes),iterations = 1) | ||
| labels[(instlabels == instance) & (objectMaskDil == 0)] = c | ||
@@ -263,3 +260,3 @@ else: | ||
| def _read_img(path, **kwargs): | ||
| "Read image and normalize to 0-1 range" | ||
| "Read image" | ||
| if path.suffix == '.zarr': | ||
@@ -269,4 +266,4 @@ img = zarr.convenience.open(path.as_posix()) | ||
| img = imageio.imread(path, **kwargs) | ||
| if img.max()>1.: | ||
| img = img/np.iinfo(img.dtype).max | ||
| #if img.max()>1.: | ||
| # img = img/np.iinfo(img.dtype).max | ||
| if img.ndim == 2: | ||
@@ -277,3 +274,3 @@ img = np.expand_dims(img, axis=2) | ||
| # Cell | ||
| def _read_msk(path, n_classes=2, instance_labels=False, remove_connectivity=True, **kwargs): | ||
| def _read_msk(path, num_classes=2, instance_labels=False, remove_connectivity=True, **kwargs): | ||
| "Read image and check classes" | ||
@@ -285,6 +282,6 @@ if path.suffix == '.zarr': | ||
| if instance_labels: | ||
| msk = preprocess_mask(clabels=None, instlabels=msk, remove_connectivity=remove_connectivity, n_classes=n_classes) | ||
| msk = preprocess_mask(clabels=None, instlabels=msk, remove_connectivity=remove_connectivity, num_classes=num_classes) | ||
| else: | ||
| # handle binary labels that are scaled different from 0 and 1 | ||
| if n_classes==2 and np.max(msk)>1 and len(np.unique(msk))==2: | ||
| if num_classes==2 and np.max(msk)>1 and len(np.unique(msk))==2: | ||
| msk = msk//np.max(msk) | ||
@@ -296,3 +293,3 @@ # Remove channels if no extra information given | ||
| # Mask check | ||
| assert len(np.unique(msk))<=n_classes, f'Expected mask with {n_classes} classes but got mask with {len(np.unique(msk))} classes {np.unique(msk)} . Are you using instance labels?' | ||
| assert len(np.unique(msk))<=num_classes, f'Expected mask with {num_classes} classes but got mask with {len(np.unique(msk))} classes {np.unique(msk)} . Are you using instance labels?' | ||
| assert len(msk.shape)==2, 'Currently, only masks with a single channel are supported.' | ||
@@ -302,10 +299,23 @@ return msk.astype('uint8') | ||
| # Cell | ||
| def tiles_in_rectangles(H, W, h, w): | ||
| '''Get smaller rectangles needed to fill the larger rectangle''' | ||
| n_H = math.ceil(float(H)/float(h)) | ||
| n_W = math.ceil(float(W)/float(w)) | ||
| return n_H*n_W | ||
| # Cell | ||
| class BaseDataset(Dataset): | ||
| def __init__(self, files, label_fn=None, instance_labels = False, n_classes=2, ignore={},remove_connectivity=True,stats=None,normalize=True, | ||
| def __init__(self, files, label_fn=None, instance_labels = False, num_classes=2, ignore={},remove_connectivity=True, | ||
| stats=None,normalize=True, use_zarr_data=True, | ||
| tile_shape=(512,512), padding=(0,0),preproc_dir=None, verbose=1, scale=1, pdf_reshape=512, use_preprocessed_labels=False, **kwargs): | ||
| store_attr('files, label_fn, instance_labels, n_classes, ignore, tile_shape, remove_connectivity, padding, preproc_dir, normalize, scale, pdf_reshape, use_preprocessed_labels') | ||
| self.c = n_classes | ||
| store_attr('files, label_fn, instance_labels, num_classes, ignore, tile_shape, remove_connectivity, padding, preproc_dir, stats, normalize, scale, pdf_reshape, use_preprocessed_labels') | ||
| self.c = num_classes | ||
| self.use_zarr_data=False | ||
| if self.normalize: | ||
| self.stats = stats or self.compute_stats() | ||
| if self.stats is None: | ||
| self.mean_sum, self.var_sum = 0., 0. | ||
| self.max_tile_count = 0 | ||
| self.actual_tile_shape = (np.array(self.tile_shape)-np.array(self.padding)) | ||
@@ -315,7 +325,9 @@ if label_fn is not None: | ||
| root = zarr.group(store=self.preproc_dir, overwrite= not use_preprocessed_labels) | ||
| self.labels, self.pdfs = root.require_groups('labels','pdfs') | ||
| self._preproc(verbose) | ||
| self.data, self.labels, self.pdfs = root.require_groups('data', 'labels','pdfs') | ||
| self._preproc(use_zarr_data=use_zarr_data, verbose=verbose) | ||
| def read_img(self, *args, **kwargs): | ||
| return _read_img(*args, **kwargs) | ||
| def read_img(self, path, **kwargs): | ||
| if self.use_zarr_data: img = self.data[path.name] | ||
| else: img = _read_img(path, **kwargs) | ||
| return img | ||
@@ -325,10 +337,18 @@ def read_mask(self, *args, **kwargs): | ||
| def _create_cdf(self, mask, ignore, fbr=None): | ||
| def _create_cdf(self, mask, ignore, sampling_weights=None, igonore_edges_pct=0): | ||
| 'Creates a cumulated probability density function (CDF) for weighted sampling ' | ||
| # Create probability density function | ||
| # Create mask | ||
| mask = mask[:] | ||
| fbr = fbr or np.sum(mask>0)/np.sum(mask==0) | ||
| pdf = (mask>0) + (mask==0) * fbr | ||
| if sampling_weights is None: | ||
| classes, counts = np.unique(mask, return_counts=True) | ||
| sampling_weights = {k:1-v/mask.size for k,v in zip(classes, counts)} | ||
| # Set pixel weights | ||
| pdf = np.zeros_like(mask, dtype=np.float32) | ||
| for k, v in sampling_weights.items(): | ||
| pdf[mask==k] = v | ||
| # Set weight and sampling probability for ignored regions to 0 | ||
@@ -338,37 +358,64 @@ if ignore is not None: | ||
| #if igonore_edges: | ||
| # w = int(self.tile_shape[0]*0.25) | ||
| # pdf[:, :w] = pdf[:, -w:] = 0 | ||
| # pdf[:w, :] = pdf[-w:, :] = 0 | ||
| if igonore_edges_pct>0: | ||
| w = int(self.tile_shape[0]*igonore_edges_pct/2) #0.25 | ||
| pdf[:, :w] = pdf[:, -w:] = 0 | ||
| pdf[:w, :] = pdf[-w:, :] = 0 | ||
| # Reshape | ||
| reshape_w = int((pdf.shape[1]/pdf.shape[0])*self.pdf_reshape) | ||
| pdf = cv2.resize(pdf, dsize=(reshape_w, self.pdf_reshape)) | ||
| # Normalize pixel weights | ||
| pdf /= pdf.sum() | ||
| return np.cumsum(pdf/np.sum(pdf)) | ||
| def _preproc_file(self, file): | ||
| "Preprocesses and saves labels (msk), weights, and pdf." | ||
| label_path = self.label_fn(file) | ||
| ign = self.ignore[file.name] if file.name in self.ignore else None | ||
| lbl = self.read_mask(label_path, n_classes=self.c, instance_labels=self.instance_labels, remove_connectivity=self.remove_connectivity) | ||
| self.labels[file.name] = lbl | ||
| self.pdfs[file.name] = self._create_cdf(lbl, ignore=ign) | ||
| def _preproc_file(self, file, use_zarr_data=True): | ||
| "Preprocesses and saves images, labels (msk), weights, and pdf." | ||
| def _preproc(self, verbose=0): | ||
| # Load and save image | ||
| img = self.read_img(file) | ||
| if self.stats is None: | ||
| self.mean_sum += img.mean((0,1)) | ||
| self.var_sum += img.var((0,1)) | ||
| self.max_tile_count = max(self.max_tile_count, tiles_in_rectangles(*img.shape[:2], *self.actual_tile_shape)) | ||
| if use_zarr_data: self.data[file.name] = img | ||
| if self.label_fn is not None: | ||
| # Load and save image | ||
| label_path = self.label_fn(file) | ||
| ign = self.ignore[file.name] if file.name in self.ignore else None | ||
| lbl = self.read_mask(label_path, num_classes=self.c, instance_labels=self.instance_labels, remove_connectivity=self.remove_connectivity) | ||
| self.labels[file.name] = lbl | ||
| self.pdfs[file.name] = self._create_cdf(lbl, ignore=ign) | ||
| def _preproc(self, use_zarr_data=True, verbose=0): | ||
| using_cache = False | ||
| for f in self.files: | ||
| if verbose>0: print('Preprocessing data') | ||
| for f in progress_bar(self.files, leave=True if verbose>0 else False): | ||
| if self.use_preprocessed_labels: | ||
| try: | ||
| self.labels[f.name] | ||
| self.pdfs[f.name] | ||
| if not using_cache: | ||
| if verbose>0: print(f'Using preprocessed masks from {self.preproc_dir}') | ||
| using_cache = True | ||
| self.data[f.name] | ||
| if self.label_fn is not None: | ||
| self.labels[f.name] | ||
| self.pdfs[f.name] | ||
| if not using_cache: | ||
| #if verbose>0: print(f'Using preprocessed data from {self.preproc_dir}') | ||
| using_cache = True | ||
| except: | ||
| if verbose>0: print('Preprocessing', f.name) | ||
| self._preproc_file(f) | ||
| self._preproc_file(f, use_zarr_data=use_zarr_data) | ||
| else: | ||
| if verbose>0: print('Preprocessing', f.name) | ||
| self._preproc_file(f) | ||
| self._preproc_file(f, use_zarr_data=use_zarr_data) | ||
| self.use_zarr_data=use_zarr_data | ||
| if self.stats is None: | ||
| n = len(self.files) | ||
| #https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch/60803379#60803379 | ||
| self.stats = {'channel_means': self.mean_sum/n, | ||
| 'channel_stds': np.sqrt(self.var_sum/n), | ||
| 'max_tiles_per_image': self.max_tile_count} | ||
| print('Calculated stats', self.stats) | ||
| def get_data(self, files=None, max_n=None, mask=False): | ||
@@ -401,22 +448,6 @@ if files is not None: | ||
| lbl = self.labels[f.name] | ||
| show(img, lbl, file_name=f.name, figsize=figsize, show_bbox=False, **kwargs) | ||
| show(img, lbl, file_name=f.name, figsize=figsize, show_bbox=False, num_classes=self.num_classes, **kwargs) | ||
| else: | ||
| show(img, file_name=f.name, figsize=figsize, show_bbox=False, **kwargs) | ||
| #https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch/60803379#60803379 | ||
| def compute_stats(self, max_samples=50): | ||
| "Computes mean and std from files" | ||
| print('Computing Stats...') | ||
| mean_sum, var_sum = 0., 0. | ||
| for i, f in enumerate(self.files, 1): | ||
| img = self.read_img(f)[:] | ||
| mean_sum += img.mean((0,1)) | ||
| var_sum += img.var((0,1)) | ||
| if i==max_samples: | ||
| print(f'Calculated stats from {i} files') | ||
| continue | ||
| self.mean = mean_sum/i | ||
| self.std = np.sqrt(var_sum/i) | ||
| return self.mean, self.std#*2 | ||
| # Cell | ||
@@ -435,7 +466,6 @@ class RandomTileDataset(BaseDataset): | ||
| if self.sample_mult is None: | ||
| #tile_shape = np.array(self.tile_shape)-np.array(self.padding) | ||
| #msk_shape = np.array(self.get_data(max_n=1)[0].shape[:-1]) | ||
| #msk_shape = np.array(lbl.shape[-2:]) | ||
| #sample_mult = int(np.product(np.floor(msk_shape/tile_shape))) | ||
| self.sample_mult = max(1, min_length//len(self.files)) | ||
| self.sample_mult = max(int(self.stats['max_tiles_per_image']/self.scale**2), | ||
| min_length//len(self.files)) | ||
@@ -445,3 +475,5 @@ tfms = self.albumentations_tfms | ||
| tfms += [ | ||
| A.Normalize(mean=self.stats[0], std=self.stats[1], max_pixel_value=1.) | ||
| A.Normalize(mean=self.stats['channel_means'], | ||
| std=self.stats['channel_stds'], | ||
| max_pixel_value=1.0) | ||
| ] | ||
@@ -491,5 +523,5 @@ self.tfms = A.Compose(tfms+[ToTensorV2()]) | ||
| n_inp = 1 | ||
| def __init__(self, *args, val_length=None, val_seed=42, is_zarr=False, shift=1., border_padding_factor=0.25, return_index=False, **kwargs): | ||
| def __init__(self, *args, val_length=None, val_seed=42, max_tile_shift=1., border_padding_factor=0.25, return_index=False, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self.shift = shift | ||
| self.max_tile_shift = max_tile_shift | ||
| self.bpf = border_padding_factor | ||
@@ -509,13 +541,8 @@ self.return_index = return_index | ||
| tfms += [ | ||
| #A.ToFloat(), | ||
| A.Normalize(mean=self.stats[0], std=self.stats[1], max_pixel_value=1.) | ||
| A.Normalize(mean=self.stats['channel_means'], | ||
| std=self.stats['channel_stds'], | ||
| max_pixel_value=1.0) | ||
| ] | ||
| self.tfms = A.Compose(tfms+[ToTensorV2()]) | ||
| if self.files[0].suffix == '.zarr' or is_zarr: | ||
| self.data = zarr.open_group(self.files[0].parent.as_posix(), mode='r') | ||
| is_zarr = True | ||
| else: | ||
| root = zarr.group(store=zarr.storage.TempStore(), overwrite=True) | ||
| self.data = root.create_group('data') | ||
@@ -525,3 +552,2 @@ j = 0 | ||
| img = self.read_img(file) | ||
| if not is_zarr: self.data[file.name] = img | ||
| # Tiling | ||
@@ -531,3 +557,3 @@ data_shape = tuple(int(x//self.scale) for x in img.shape[:-1]) | ||
| end_points = [(s - st) for s, st in zip(data_shape, start_points)] | ||
| n_points = [int((s+2*o*self.bpf)//(o*self.shift))+1 for s, o in zip(data_shape, self.output_shape)] | ||
| n_points = [int((s+2*o*self.bpf)//(o*self.max_tile_shift))+1 for s, o in zip(data_shape, self.output_shape)] | ||
| center_points = [np.linspace(st, e, num=n, endpoint=True, dtype=np.int64) for st, e, n in zip(start_points, end_points, n_points)] | ||
@@ -571,3 +597,4 @@ for cx in center_points[1]: | ||
| img_path = self.files[self.image_indices[idx]] | ||
| img = self.data[img_path.name] | ||
| #img = self.data[img_path.name] | ||
| img = self.read_img(img_path) | ||
| centerPos = self.centers[idx] | ||
@@ -574,0 +601,0 @@ |
+32
-46
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09_gt.ipynb (unless otherwise specified). | ||
| __all__ = ['import_sitk', 'staple', 'm_voting', 'msk_show', 'GTEstimator'] | ||
| __all__ = ['import_sitk', 'staple_multi_label', 'm_voting', 'GTEstimator'] | ||
@@ -11,7 +11,8 @@ # Cell | ||
| from fastai.data.transforms import get_image_files | ||
| import matplotlib.pyplot as plt | ||
| from .data import _read_msk | ||
| from .learner import Config | ||
| from .utils import save_mask, dice_score, install_package, get_instance_segmentation_metrics | ||
| from .config import Config | ||
| from .utils import clean_show, save_mask, dice_score, install_package, get_instance_segmentation_metrics | ||
@@ -30,10 +31,10 @@ # Cell | ||
| # Cell | ||
| def staple(segmentations, foregroundValue = 1, threshold = 0.5): | ||
| def staple_multi_label(segmentations, label_undecided_pixel=1): | ||
| 'STAPLE: Simultaneous Truth and Performance Level Estimation with simple ITK' | ||
| sitk = import_sitk() | ||
| segmentations = [sitk.GetImageFromArray(x) for x in segmentations] | ||
| STAPLE_probabilities = sitk.STAPLE(segmentations) | ||
| STAPLE = STAPLE_probabilities > threshold | ||
| #STAPLE = sitk.GetArrayViewFromImage(STAPLE) | ||
| return sitk.GetArrayFromImage(STAPLE) | ||
| sitk_segmentations = [sitk.GetImageFromArray(x) for x in segmentations] | ||
| STAPLE = sitk.MultiLabelSTAPLEImageFilter() | ||
| STAPLE.SetLabelForUndecidedPixels(label_undecided_pixel) | ||
| msk = STAPLE.Execute(sitk_segmentations) | ||
| return sitk.GetArrayFromImage(msk) | ||
@@ -49,18 +50,2 @@ # Cell | ||
| # Cell | ||
| from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
| def msk_show(ax, msk, title, cbar=None, ticks=None, **kwargs): | ||
| img = ax.imshow(msk, **kwargs) | ||
| if cbar is not None: | ||
| divider = make_axes_locatable(ax) | ||
| cax = divider.append_axes("right", size="5%", pad=0.05) | ||
| if cbar=='plot': | ||
| scale = ticks/(ticks+1) | ||
| cbr = plt.colorbar(img, cax=cax, ticks=[i*(scale)+(scale/2) for i in range(0, ticks+1)]) | ||
| cbr.set_ticklabels([i for i in range(0, ticks+1)]) | ||
| cbr.set_label('# of experts', rotation=270, labelpad=+15, fontsize="larger") | ||
| else: cax.set_axis_off() | ||
| ax.set_axis_off() | ||
| ax.set_title(title) | ||
| # Cell | ||
| class GTEstimator(GetAttr): | ||
@@ -104,8 +89,9 @@ "Class for ground truth estimation" | ||
| fig, axs = plt.subplots(nrows=1, ncols=len(exps), figsize=figsize, **kwargs) | ||
| vkwargs = {'vmin':0, 'vmax':self.num_classes-1} | ||
| for i, exp in enumerate(exps): | ||
| try: | ||
| msk = _read_msk(self.mask_fn(exp,m), instance_labels=self.instance_labels) | ||
| except: | ||
| raise ValueError('Ground truth estimation currently only suppports two classes (binary masks or instance labels)') | ||
| msk_show(axs[i], msk, exp, cmap=self.cmap) | ||
| msk = _read_msk(self.mask_fn(exp,m), num_classes=self.num_classes, instance_labels=self.instance_labels) | ||
| if i == len(exps) - 1: | ||
| clean_show(axs[i], msk, exp, self.cmap, cbar='classes', ticks=self.num_classes, **vkwargs) | ||
| else: | ||
| clean_show(axs[i], msk, exp, self.cmap, **vkwargs) | ||
| fig.text(0, .5, m, ha='center', va='center', rotation=90) | ||
@@ -121,10 +107,11 @@ plt.tight_layout() | ||
| for m, exps in progress_bar(self.masks.items()): | ||
| masks = [_read_msk(self.mask_fn(exp,m), instance_labels=self.instance_labels) for exp in exps] | ||
| masks = [_read_msk(self.mask_fn(exp,m), num_classes=self.num_classes, instance_labels=self.instance_labels) for exp in exps] | ||
| if method=='STAPLE': | ||
| ref = staple(masks, self.staple_fval, self.staple_thres) | ||
| #ref = staple(masks, self.staple_fval, self.staple_thres) | ||
| ref = staple_multi_label(masks, self.vote_undec) | ||
| elif method=='majority_voting': | ||
| ref = m_voting(masks, self.majority_vote_undec) | ||
| ref = m_voting(masks, self.vote_undec) | ||
| refs[m] = ref | ||
| #assert ref.mean() > 0, 'Please try again!' | ||
| df_tmp = pd.DataFrame({'method': method, 'file' : m, 'exp' : exps, 'dice_score': [dice_score(ref, msk) for msk in masks]}) | ||
| df_tmp = pd.DataFrame({'method': method, 'file' : m, 'exp' : exps, 'dice_score': [dice_score(ref, msk, num_classes=self.num_classes) for msk in masks]}) | ||
| if self.instance_segmentation_metrics: | ||
@@ -164,9 +151,13 @@ mAP, AP = [],[] | ||
| for f in files: | ||
| fig, ax = plt.subplots(ncols=2, figsize=figsize, **kwargs) | ||
| # GT | ||
| msk_show(ax[0], self.gt[method][f], f'{method} (binary mask)', cbar='', cmap=self.cmap) | ||
| # Experts | ||
| masks = [_read_msk(self.mask_fn(exp,f), instance_labels=self.instance_labels) for exp in self.masks[f]] | ||
| masks_av = np.array(masks).sum(axis=0)#/len(masks) | ||
| msk_show(ax[1], masks_av, 'Expert Overlay', cbar='plot', ticks=len(masks), cmap=plt.cm.get_cmap(self.cmap, len(masks)+1)) | ||
| if self.num_classes==2: | ||
| fig, ax = plt.subplots(ncols=2, figsize=figsize, **kwargs) | ||
| # GT | ||
| clean_show(ax[0], self.gt[method][f], f'{method} (binary mask)', cbar='', cmap=self.cmap) | ||
| # Experts | ||
| masks = [_read_msk(self.mask_fn(exp,f), num_classes=self.num_classes, instance_labels=self.instance_labels) for exp in self.masks[f]] | ||
| masks_av = np.array(masks).sum(axis=0)#/len(masks) | ||
| clean_show(ax[1], masks_av, 'Expert Overlay', cbar='experts', ticks=len(masks), cmap=plt.cm.get_cmap(self.cmap, len(masks)+1)) | ||
| else: | ||
| fig, ax = plt.subplots(ncols=1, figsize=figsize, **kwargs) | ||
| clean_show(ax, self.gt[method][f], f'{method}', cbar='classes', cmap=self.cmap, ticks=self.num_classes) | ||
| # Results | ||
@@ -176,7 +167,2 @@ metrics = ['dice_score', 'mean_average_precision', 'average_precision_at_iou_50'] if self.instance_segmentation_metrics else ['dice_score'] | ||
| plt_df = self.df_res[self.df_res.file==f].set_index('exp')[metrics].append(av_df) | ||
| #plt_df.columns = [f'Similarity (Dice Score)'] | ||
| #tbl = pd.plotting.table(ax[2], np.round(plt_df,3), loc='center', colWidths=[.5]) | ||
| #tbl.set_fontsize(14) | ||
| #tbl.scale(1, 2) | ||
| #ax[2].set_axis_off() | ||
| fig.text(0, .5, f, ha='center', va='center', rotation=90) | ||
@@ -183,0 +169,0 @@ plt.tight_layout() |
+301
-479
@@ -1,27 +0,23 @@ | ||
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_learner.ipynb (unless otherwise specified). | ||
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_learner.ipynb (unless otherwise specified). | ||
| __all__ = ['Config', 'energy_score', 'EnsemblePredict', 'EnsembleLearner'] | ||
| __all__ = ['EnsembleBase', 'EnsembleLearner', 'EnsemblePredictor'] | ||
| # Cell | ||
| import shutil, gc, joblib, json, zarr, numpy as np, pandas as pd | ||
| import torch | ||
| import time | ||
| import tifffile, cv2 | ||
| import torch, torch.nn as nn, torch.nn.functional as F | ||
| from torch.utils.data import DataLoader | ||
| from dataclasses import dataclass, field, asdict | ||
| import zarr | ||
| import pandas as pd | ||
| import numpy as np | ||
| import cv2 | ||
| import tifffile | ||
| from pathlib import Path | ||
| from typing import List, Union, Tuple | ||
| from sklearn import svm | ||
| from skimage.color import label2rgb | ||
| from sklearn.model_selection import KFold | ||
| from sklearn.pipeline import Pipeline | ||
| from sklearn.preprocessing import StandardScaler | ||
| from scipy.ndimage.filters import gaussian_filter | ||
| from skimage.color import label2rgb | ||
| import matplotlib.pyplot as plt | ||
| from fastprogress import progress_bar | ||
| from fastcore.basics import patch, GetAttr | ||
| from fastcore.foundation import add_docs, L | ||
| from fastcore.basics import GetAttr | ||
| from fastcore.foundation import L | ||
| from fastai import optimizer | ||
| from fastai.torch_core import TensorImage | ||
| from fastai.learner import Learner | ||
@@ -32,120 +28,19 @@ from fastai.callback.tracker import SaveModelCallback | ||
| from fastai.data.transforms import get_image_files, get_files | ||
| from fastai.vision.augment import Brightness, Contrast, Saturation | ||
| from fastai.losses import CrossEntropyLossFlat | ||
| from fastai.metrics import Dice, DiceMulti | ||
| from .config import Config | ||
| from .data import BaseDataset, TileDataset, RandomTileDataset | ||
| from .models import create_smp_model, save_smp_model, load_smp_model, run_cellpose | ||
| from .inference import InferenceEnsemble | ||
| from .losses import get_loss | ||
| from .models import create_smp_model, save_smp_model, load_smp_model, run_cellpose | ||
| from .data import TileDataset, RandomTileDataset, _read_img, _read_msk | ||
| from .utils import dice_score, plot_results, get_label_fn, calc_iterations, save_mask, save_unc, export_roi_set, get_instance_segmentation_metrics | ||
| from .utils import compose_albumentations as _compose_albumentations | ||
| import deepflash2.tta as tta | ||
| from .utils import dice_score, binary_dice_score, plot_results, get_label_fn, save_mask, save_unc, export_roi_set, get_instance_segmentation_metrics | ||
| from fastai.metrics import Dice, DiceMulti | ||
| # Cell | ||
| @dataclass | ||
| class Config: | ||
| "Config class for settings." | ||
| import matplotlib.pyplot as plt | ||
| import warnings | ||
| # Project | ||
| project_dir:str = '.' | ||
| #https://discuss.pytorch.org/t/slow-forward-on-traced-graph-on-cuda-2nd-iteration/118445/7 | ||
| try: torch._C._jit_set_fusion_strategy([('STATIC', 0)]) | ||
| except: torch._C._jit_set_bailout_depth(0) | ||
| # GT Estimation Settings | ||
| staple_thres:float = 0.5 | ||
| staple_fval:int= 1 | ||
| majority_vote_undec:int = 1 | ||
| # Train General Settings | ||
| n_models:int = 5 | ||
| max_splits:int=5 | ||
| random_state:int = 42 | ||
| # Pytorch Segmentation Model Settings | ||
| arch:str = 'Unet' | ||
| encoder_name:str = 'resnet34' | ||
| encoder_weights:str = 'imagenet' | ||
| # Train Data Settings | ||
| n_classes:int = 2 | ||
| tile_shape:int = 512 | ||
| instance_labels:bool = False | ||
| # Train Settings | ||
| base_lr:float = 0.001 | ||
| batch_size:int = 4 | ||
| weight_decay:float = 0.001 | ||
| mixed_precision_training:bool = False | ||
| optim:str = 'Adam' | ||
| loss:str = 'CrossEntropyDiceLoss' | ||
| n_iter:int = 2500 | ||
| sample_mult:int = 0 | ||
| # Validation and Prediction Settings | ||
| tta:bool = True | ||
| border_padding_factor:float = 0.25 | ||
| shift:float = 0.5 | ||
| # Train Data Augmentation | ||
| gamma_limit_lower:int = 80 | ||
| gamma_limit_upper:int = 120 | ||
| CLAHE_clip_limit:float = 0.0 | ||
| brightness_limit:float = 0.0 | ||
| contrast_limit:float = 0.0 | ||
| flip:bool = True | ||
| rot:int = 360 | ||
| distort_limit:float = 0 | ||
| # Loss Settings | ||
| mode:str = 'multiclass' #currently only tested for multiclass | ||
| loss_alpha:float = 0.5 # Twerksky/Focal loss | ||
| loss_beta:float = 0.5 # Twerksy Loss | ||
| loss_gamma:float = 2.0 # Focal loss | ||
| loss_smooth_factor:float = 0. #SoftCrossEntropyLoss | ||
| # Pred Settings | ||
| pred_tta:bool = True | ||
| min_pixel_export:int = 0 | ||
| # Instance Segmentation Settings | ||
| cellpose_model:str='nuclei' | ||
| cellpose_diameter:int=0 | ||
| cellpose_export_class:int=1 | ||
| instance_segmentation_metrics:bool=False | ||
| # Folder Structure | ||
| gt_dir:str = 'GT_Estimation' | ||
| train_dir:str = 'Training' | ||
| pred_dir:str = 'Prediction' | ||
| ens_dir:str = 'models' | ||
| val_dir:str = 'valid' | ||
| @property | ||
| def albumentation_kwargs(self): | ||
| kwargs = ['gamma_limit_lower', 'gamma_limit_upper', 'CLAHE_clip_limit', | ||
| 'brightness_limit', 'contrast_limit', 'distort_limit'] | ||
| return dict(filter(lambda x: x[0] in kwargs, self.__dict__.items())) | ||
| @property | ||
| def svm_kwargs(self): | ||
| svm_vars = ['kernel', 'nu', 'gamma'] | ||
| return dict(filter(lambda x: x[0] in svm_vars, self.__dict__.items())) | ||
| def save(self, path): | ||
| 'Save configuration to path' | ||
| path = Path(path).with_suffix('.json') | ||
| with open(path, 'w') as config_file: | ||
| json.dump(asdict(self), config_file) | ||
| print(f'Saved current configuration to {path}.json') | ||
| return path | ||
| def load(self, path): | ||
| 'Load configuration from path' | ||
| path = Path(path) | ||
| try: | ||
| with open(path) as config_file: c = json.load(config_file) | ||
| if not Path(c['project_dir']).is_dir(): c['project_dir']='deepflash2' | ||
| for k,v in c.items(): setattr(self, k, v) | ||
| print(f'Successsfully loaded configuration from {path}') | ||
| except: | ||
| print('Error! Select valid config file (.json)') | ||
| # Cell | ||
@@ -164,170 +59,71 @@ _optim_dict = { | ||
| # Cell | ||
| # from https://github.com/MIC-DKFZ/nnUNet/blob/2fade8f32607220f8598544f0d5b5e5fa73768e5/nnunet/network_architecture/neural_network.py#L250 | ||
| def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray: | ||
| tmp = np.zeros(patch_size) | ||
| center_coords = [i // 2 for i in patch_size] | ||
| sigmas = [i * sigma_scale for i in patch_size] | ||
| tmp[tuple(center_coords)] = 1 | ||
| gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) | ||
| gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1 | ||
| gaussian_importance_map = gaussian_importance_map.astype(np.float32) | ||
| class EnsembleBase(GetAttr): | ||
| _default = 'config' | ||
| def __init__(self, image_dir:str=None, mask_dir:str=None, files:List[Path]=None, label_fn:callable=None, | ||
| config:Config=None, path:Path=None, zarr_store:str=None): | ||
| # gaussian_importance_map cannot be 0, otherwise we may end up with nans! | ||
| gaussian_importance_map[gaussian_importance_map == 0] = np.min( | ||
| gaussian_importance_map[gaussian_importance_map != 0]) | ||
| self.config = config or Config() | ||
| self.path = Path(path) if path is not None else Path('.') | ||
| self.label_fn = None | ||
| self.files = L() | ||
| return gaussian_importance_map | ||
| store = str(zarr_store) if zarr_store else zarr.storage.TempStore() | ||
| root = zarr.group(store=store, overwrite=False) | ||
| self.store = root.chunk_store.path | ||
| self.g_pred, self.g_smx, self.g_std = root.require_groups('preds', 'smxs', 'stds') | ||
| # Cell | ||
| def energy_score(x, T=1, dim=1): | ||
| 'Return the energy score as proposed by Liu, Weitang, et al. (2020).' | ||
| return -(T*torch.logsumexp(x/T, dim=dim)) | ||
| if any(v is not None for v in (image_dir, files)): | ||
| self.files = L(files) or self.get_images(image_dir) | ||
| # Cell | ||
| class EnsemblePredict(): | ||
| 'Class for prediction with multiple models' | ||
| def __init__(self, models_paths, zarr_store=None): | ||
| self.models_paths = models_paths | ||
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
| self.init_models() | ||
| if any(v is not None for v in (mask_dir, label_fn)): | ||
| assert hasattr(self, 'files'), 'image_dir or files must be provided' | ||
| self.label_fn = label_fn or self.get_label_fn(mask_dir) | ||
| self.check_label_fn() | ||
| # Init zarr storage | ||
| self.store = str(zarr_store) if zarr_store else zarr.storage.TempStore() | ||
| self.root = zarr.group(store=self.store) | ||
| self.g_smx = self.root.require_group('smx') | ||
| self.g_eng, self.g_std = None, None | ||
| def get_images(self, img_dir:str='images', img_path:Path=None) -> List[Path]: | ||
| 'Returns list of image paths' | ||
| path = img_path or self.path/img_dir | ||
| files = get_image_files(path, recurse=False) | ||
| print(f'Found {len(files)} images in "{path}".') | ||
| if len(files)==0: warnings.warn('Please check your provided images and image folder') | ||
| return files | ||
| def init_models(self): | ||
| self.models = [] | ||
| self.stats = None | ||
| for p in self.models_paths: | ||
| model, stats = load_smp_model(p) | ||
| if not self.stats: self.stats = stats | ||
| assert np.array_equal(stats, self.stats), 'Only models trained on the same stats are allowed.' | ||
| model.float() | ||
| model.eval() | ||
| model.to(self.device) | ||
| self.models.append(model) | ||
| def get_label_fn(self, msk_dir:str='masks', msk_path:Path=None): | ||
| 'Returns label function to get paths of masks' | ||
| path = msk_path or self.path/msk_dir | ||
| return get_label_fn(self.files[0], path) | ||
| def predict(self, | ||
| ds, | ||
| use_tta=True, | ||
| bs=4, | ||
| use_gaussian=True, | ||
| sigma_scale=1./8, | ||
| uncertainty_estimates=True, | ||
| uncertainty_type = 'uncertainty', | ||
| energy_scores=False, | ||
| energy_T = 1., | ||
| verbose=0): | ||
| def check_label_fn(self): | ||
| 'Checks label function' | ||
| mask_check = [self.label_fn(x).exists() for x in self.files] | ||
| chk_str = f'Found {sum(mask_check)} corresponding masks.' | ||
| print(chk_str) | ||
| if len(self.files)!=sum(mask_check): | ||
| warnings.warn(f'Please check your images and masks (and folders).') | ||
| if verbose>0: print('Ensemble prediction with models:', self.models_paths) | ||
| def predict(self, arr:Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
| 'Get prediction for arr using inference_ensemble' | ||
| inp = torch.tensor(arr).float().to(self.device) | ||
| with torch.inference_mode(): | ||
| preds = self.inference_ensemble(inp) | ||
| preds = [x.cpu().numpy() for x in preds] | ||
| return tuple(preds) | ||
| tfms = [tta.HorizontalFlip(),tta.VerticalFlip()] if use_tta else [] | ||
| if verbose>0: print('Using Test-Time Augmentation with:', tfms) | ||
| def save_preds_zarr(self, f_name, pred, smx, std): | ||
| self.g_pred[f_name] = pred | ||
| self.g_smx[f_name] = smx | ||
| self.g_std[f_name] = std | ||
| dl = DataLoader(ds, bs, num_workers=0, shuffle=False, pin_memory=True) | ||
| def _create_ds(self, **kwargs): | ||
| self.ds = BaseDataset(self.files, label_fn=self.label_fn, instance_labels=self.instance_labels, | ||
| num_classes=self.num_classes, **kwargs) | ||
| # Create zero arrays | ||
| data_shape = ds.image_shapes[0] | ||
| softmax = np.zeros((*data_shape, ds.c), dtype='float32') | ||
| merge_map = np.zeros(data_shape, dtype='float32') | ||
| stdeviation = np.zeros(data_shape, dtype='float32') if uncertainty_estimates else None | ||
| energy = np.zeros(data_shape, dtype='float32') if energy_scores else None | ||
| # Define merge weights | ||
| if use_gaussian: | ||
| mw_numpy = _get_gaussian(ds.output_shape, sigma_scale) | ||
| else: | ||
| mw_numpy = np.ones(dl.output_shape) | ||
| mw = torch.from_numpy(mw_numpy).to(self.device) | ||
| # Loop over tiles (indices required!) | ||
| for tiles, idxs in iter(dl): | ||
| tiles = tiles.to(self.device) | ||
| smx_merger = tta.Merger() | ||
| if energy_scores: | ||
| energy_merger = tta.Merger() | ||
| # Loop over tt-augmentations | ||
| for t in tta.Compose(tfms): | ||
| aug_tiles = t.augment_image(tiles) | ||
| model_merger = tta.Merger() | ||
| if energy_scores: engergy_list = [] | ||
| # Loop over models | ||
| for model in self.models: | ||
| with torch.inference_mode(): | ||
| logits = model(aug_tiles) | ||
| logits = t.deaugment_mask(logits) | ||
| smx_merger.append(F.softmax(logits, dim=1)) | ||
| if energy_scores: | ||
| energy_merger.append(-energy_score(logits, energy_T)) #negative energy score | ||
| out_list = [] | ||
| # Apply gaussian weigthing | ||
| batch_smx = smx_merger.result()*mw.view(1,1,*mw.shape) | ||
| # Reshape and append to list | ||
| out_list.append([x for x in batch_smx.permute(0,2,3,1).cpu().numpy()]) | ||
| if uncertainty_estimates: | ||
| batch_std = torch.mean(smx_merger.result(uncertainty_type), dim=1)*mw.view(1,*mw.shape) | ||
| out_list.append([x for x in batch_std.cpu().numpy()]) | ||
| if energy_scores: | ||
| batch_energy = energy_merger.result()*mw.view(1,*mw.shape) | ||
| out_list.append([x for x in batch_energy.cpu().numpy()]) | ||
| # Compose predictions | ||
| for preds in zip(*out_list, idxs): | ||
| if len(preds)==4: smx,std,eng,idx = preds | ||
| elif uncertainty_estimates: smx,std,idx = preds | ||
| elif energy_scores: smx,eng,idx = preds | ||
| else: smx, idx = preds | ||
| out_slice = ds.out_slices[idx] | ||
| in_slice = ds.in_slices[idx] | ||
| softmax[out_slice] += smx[in_slice] | ||
| merge_map[out_slice] += mw_numpy[in_slice] | ||
| if uncertainty_estimates: | ||
| stdeviation[out_slice] += std[in_slice] | ||
| if energy_scores: | ||
| energy[out_slice] += eng[in_slice] | ||
| # Normalize weighting | ||
| softmax /= merge_map[..., np.newaxis] | ||
| if uncertainty_estimates: | ||
| stdeviation /= merge_map | ||
| if energy_scores: | ||
| energy /= merge_map | ||
| return softmax, stdeviation, energy | ||
| def predict_images(self, image_list, ds_kwargs={}, verbose=1, **kwargs): | ||
| "Predict images in 'image_list' with kwargs and save to zarr" | ||
| for f in progress_bar(image_list, leave=False): | ||
| if verbose>0: print(f'Predicting {f.name}') | ||
| ds = TileDataset([f], stats=self.stats, return_index=True, **ds_kwargs) | ||
| softmax, stdeviation, energy = self.predict(ds, **kwargs) | ||
| # Save to zarr | ||
| self.g_smx[f.name] = softmax | ||
| if stdeviation is not None: | ||
| self.g_std = self.root.require_group('std') | ||
| self.g_std[f.name] = stdeviation | ||
| if energy is not None: | ||
| self.g_eng = self.root.require_group('energy') | ||
| self.g_eng[f.name] = energy | ||
| return self.g_smx, self.g_std, self.g_eng | ||
| # Cell | ||
| class EnsembleLearner(GetAttr): | ||
| _default = 'config' | ||
| def __init__(self, image_dir='images', mask_dir=None, config=None, path=None, ensemble_path=None, preproc_dir=None, | ||
| label_fn=None, metrics=None, cbs=None, ds_kwargs={}, dl_kwargs={}, model_kwargs={}, stats=None, files=None): | ||
| class EnsembleLearner(EnsembleBase): | ||
| "Meta class to training model ensembles with `n` models" | ||
| def __init__(self, *args, ensemble_path=None, preproc_dir=None, metrics=None, cbs=None, | ||
| ds_kwargs={}, dl_kwargs={}, model_kwargs={}, stats=None, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self.config = config or Config() | ||
| assert hasattr(self, 'label_fn'), 'mask_dir or label_fn must be provided.' | ||
| self.stats = stats | ||
@@ -337,38 +133,15 @@ self.dl_kwargs = dl_kwargs | ||
| self.add_ds_kwargs = ds_kwargs | ||
| self.path = Path(path) if path is not None else Path('.') | ||
| default_metrics = [Dice()] if self.n_classes==2 else [DiceMulti()] | ||
| default_metrics = [Dice()] if self.num_classes==2 else [DiceMulti()] | ||
| self.metrics = metrics or default_metrics | ||
| self.loss_fn = self.get_loss() | ||
| self.cbs = cbs or [SaveModelCallback(monitor='dice' if self.n_classes==2 else 'dice_multi')] #ShowGraphCallback | ||
| self.cbs = cbs or [SaveModelCallback(monitor='dice' if self.num_classes==2 else 'dice_multi')] #ShowGraphCallback | ||
| self.ensemble_dir = ensemble_path or self.path/self.ens_dir | ||
| if ensemble_path is not None: | ||
| ensemble_path.mkdir(exist_ok=True, parents=True) | ||
| self.load_ensemble(path=ensemble_path) | ||
| self.load_models(path=ensemble_path) | ||
| else: self.models = {} | ||
| self.files = L(files) or get_image_files(self.path/image_dir, recurse=False) | ||
| assert len(self.files)>0, f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder' | ||
| if any([mask_dir, label_fn]): | ||
| if label_fn: self.label_fn = label_fn | ||
| else: self.label_fn = get_label_fn(self.files[0], self.path/mask_dir) | ||
| #Check if corresponding masks exist | ||
| mask_check = [self.label_fn(x).exists() for x in self.files] | ||
| chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".' | ||
| assert len(self.files)==sum(mask_check) and len(self.files)>0, f'Please check your images and masks (and folders). {chk_str}' | ||
| print(chk_str) | ||
| else: | ||
| self.label_fn = label_fn | ||
| self.n_splits=min(len(self.files), self.max_splits) | ||
| self._set_splits() | ||
| self.ds = RandomTileDataset(self.files, label_fn=self.label_fn, | ||
| preproc_dir=preproc_dir, | ||
| instance_labels=self.instance_labels, | ||
| n_classes=self.n_classes, | ||
| stats=self.stats, | ||
| normalize = True, | ||
| sample_mult=self.sample_mult if self.sample_mult>0 else None, | ||
| verbose=0, | ||
| **self.add_ds_kwargs) | ||
| self._create_ds(stats=self.stats, preproc_dir=preproc_dir, verbose=1, **self.add_ds_kwargs) | ||
| self.stats = self.ds.stats | ||
@@ -397,4 +170,5 @@ self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1] | ||
| ds_kwargs['tile_shape']= (self.tile_shape,)*2 | ||
| ds_kwargs['n_classes']= self.n_classes | ||
| ds_kwargs['shift']= self.shift | ||
| ds_kwargs['num_classes']= self.num_classes | ||
| ds_kwargs['max_tile_shift']= self.max_tile_shift | ||
| ds_kwargs['scale']= self.scale | ||
| ds_kwargs['border_padding_factor']= self.border_padding_factor | ||
@@ -413,6 +187,8 @@ return ds_kwargs | ||
| ds_kwargs['tile_shape']= (self.tile_shape,)*2 | ||
| ds_kwargs['n_classes']= self.n_classes | ||
| ds_kwargs['shift']= 1. | ||
| ds_kwargs['num_classes']= self.num_classes | ||
| ds_kwargs['scale']= self.scale | ||
| ds_kwargs['flip'] = self.flip | ||
| ds_kwargs['max_tile_shift']= 1. | ||
| ds_kwargs['border_padding_factor']= 0. | ||
| ds_kwargs['flip'] = self.flip | ||
| ds_kwargs['scale']= self.scale | ||
| ds_kwargs['albumentations_tfms'] = self._compose_albumentations(**self.albumentation_kwargs) | ||
@@ -424,7 +200,8 @@ ds_kwargs['sample_mult'] = self.sample_mult if self.sample_mult>0 else None | ||
| def model_name(self): | ||
| return f'{self.arch}_{self.encoder_name}_{self.n_classes}classes' | ||
| encoder_name = self.encoder_name.replace('_', '-') | ||
| return f'{self.arch}_{encoder_name}_{self.num_classes}classes' | ||
| def get_loss(self): | ||
| kwargs = {'mode':self.mode, | ||
| 'classes':[x for x in range(1, self.n_classes)], | ||
| 'classes':[x for x in range(1, self.num_classes)], | ||
| 'smooth_factor': self.loss_smooth_factor, | ||
@@ -439,9 +216,8 @@ 'alpha':self.loss_alpha, | ||
| ds = [] | ||
| ds.append(RandomTileDataset(files, label_fn=self.label_fn, **self.train_ds_kwargs)) | ||
| ds.append(RandomTileDataset(files, label_fn=self.label_fn, **self.train_ds_kwargs, verbose=0)) | ||
| if files_val: | ||
| ds.append(TileDataset(files_val, label_fn=self.label_fn, **self.train_ds_kwargs)) | ||
| ds.append(TileDataset(files_val, label_fn=self.label_fn, **self.train_ds_kwargs, verbose=0)) | ||
| else: | ||
| ds.append(ds[0]) | ||
| dls = DataLoaders.from_dsets(*ds, bs=self.batch_size, pin_memory=True, **self.dl_kwargs) | ||
| if torch.cuda.is_available(): dls.cuda() | ||
| dls = DataLoaders.from_dsets(*ds, bs=self.batch_size, pin_memory=True, **self.dl_kwargs).to(self.device) | ||
| return dls | ||
@@ -454,11 +230,12 @@ | ||
| in_channels=self.in_channels, | ||
| classes=self.n_classes, | ||
| **self.model_kwargs) | ||
| if torch.cuda.is_available(): model.cuda() | ||
| classes=self.num_classes, | ||
| **self.model_kwargs).to(self.device) | ||
| return model | ||
| def fit(self, i, n_iter=None, base_lr=None, **kwargs): | ||
| n_iter = n_iter or self.n_iter | ||
| def fit(self, i, n_epochs=None, base_lr=None, **kwargs): | ||
| 'Fit model number `i`' | ||
| n_epochs = n_epochs or self.n_epochs | ||
| base_lr = base_lr or self.base_lr | ||
| name = self.ensemble_dir/f'{self.model_name}-fold{i}.pth' | ||
| name = self.ensemble_dir/'single_models'/f'{self.model_name}-fold{i}.pth' | ||
| model = self._create_model() | ||
@@ -480,5 +257,3 @@ files_train, files_val = self.splits[i] | ||
| print(f'Starting training for {name.name}') | ||
| epochs = calc_iterations(n_iter=n_iter,ds_length=len(dls.train_ds), bs=self.batch_size) | ||
| #self.learn.fit_one_cycle(epochs, lr_max) | ||
| self.learn.fine_tune(epochs, base_lr=base_lr) | ||
| self.learn.fine_tune(n_epochs, base_lr=base_lr) | ||
@@ -491,8 +266,34 @@ print(f'Saving model at {name}') | ||
| def fit_ensemble(self, n_iter, skip=False, **kwargs): | ||
| del model | ||
| if torch.cuda.is_available(): torch.cuda.empty_cache() | ||
| def get_inference_ensemble(self, model_path=None): | ||
| model_paths = [model_path] if model_path is not None else self.models.values() | ||
| models = [load_smp_model(p)[0] for p in model_paths] | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore") | ||
| ensemble = InferenceEnsemble(models, | ||
| num_classes=self.num_classes, | ||
| in_channels=self.in_channels, | ||
| channel_means=self.stats['channel_means'].tolist(), | ||
| channel_stds=self.stats['channel_stds'].tolist(), | ||
| tile_shape=(self.tile_shape,)*2, | ||
| **self.inference_kwargs).to(self.device) | ||
| return torch.jit.script(ensemble) | ||
| def save_inference_ensemble(self): | ||
| ensemble = self.get_inference_ensemble() | ||
| ensemble_name = self.ensemble_dir/f'ensemble_{self.model_name}.pt' | ||
| print(f'Saving model at {ensemble_name}') | ||
| ensemble.save(ensemble_name) | ||
| def fit_ensemble(self, n_epochs=None, skip=False, save_inference_ensemble=True, **kwargs): | ||
| 'Fit `i` models and `skip` existing' | ||
| for i in range(1, self.n_models+1): | ||
| if skip and (i in self.models): continue | ||
| self.fit(i, n_iter, **kwargs) | ||
| self.fit(i, n_epochs, **kwargs) | ||
| if save_inference_ensemble: self.save_inference_ensemble() | ||
| def set_n(self, n): | ||
| "Change to `n` models per ensemble" | ||
| for i in range(n, len(self.models)): | ||
@@ -503,4 +304,7 @@ self.models.pop(i+1, None) | ||
| def get_valid_results(self, model_no=None, zarr_store=None, export_dir=None, filetype='.png', **kwargs): | ||
| "Validate models on validation data and save results" | ||
| res_list = [] | ||
| model_list = self.models if not model_no else {k:v for k,v in self.models.items() if k==model_no} | ||
| model_dict = self.models if not model_no else {k:v for k,v in self.models.items() if k==model_no} | ||
| metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score' | ||
| if export_dir: | ||
@@ -513,33 +317,31 @@ export_dir = Path(export_dir) | ||
| for i, model_path in model_list.items(): | ||
| ep = EnsemblePredict(models_paths=[model_path], zarr_store=zarr_store) | ||
| for i, model_path in model_dict.items(): | ||
| print(f'Validating model {i}.') | ||
| self.inference_ensemble = self.get_inference_ensemble(model_path=model_path) | ||
| _, files_val = self.splits[i] | ||
| g_smx, g_std, g_eng = ep.predict_images(files_val, bs=self.batch_size, ds_kwargs=self.pred_ds_kwargs, **kwargs) | ||
| del ep | ||
| torch.cuda.empty_cache() | ||
| chunk_store = g_smx.chunk_store.path | ||
| for j, f in enumerate(files_val): | ||
| msk = self.ds.get_data(f, mask=True)[0] | ||
| pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8') | ||
| m_dice = dice_score(msk, pred) | ||
| m_path = self.models[i].name | ||
| for j, f in progress_bar(enumerate(files_val), total=len(files_val)): | ||
| pred, smx, std = self.predict(self.ds.data[f.name][:]) | ||
| self.save_preds_zarr(f.name, pred, smx, std) | ||
| msk = self.ds.labels[f.name][:] #.get_data(f, mask=True)[0]) | ||
| m_dice = dice_score(msk, pred, num_classes=self.num_classes) | ||
| df_tmp = pd.Series({'file' : f.name, | ||
| 'model' : m_path, | ||
| 'model' : model_path, | ||
| 'model_no' : i, | ||
| 'dice_score': m_dice, | ||
| #'mean_energy': np.mean(g_eng[f.name][:][pred>0]), | ||
| 'uncertainty_score': np.mean(g_std[f.name][:][pred>0]) if g_std is not None else None, | ||
| metric_name: m_dice, | ||
| 'uncertainty_score': np.mean(std[pred>0]), | ||
| 'image_path': f, | ||
| 'mask_path': self.label_fn(f), | ||
| 'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}', | ||
| 'engergy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None, | ||
| 'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None}) | ||
| 'pred_path': f'{self.store}/{self.g_pred.path}/{f.name}', | ||
| 'softmax_path': f'{self.store}/{self.g_smx.path}/{f.name}', | ||
| 'uncertainty_path': f'{self.store}/{self.g_std.path}/{f.name}'}) | ||
| res_list.append(df_tmp) | ||
| if export_dir: | ||
| save_mask(pred, pred_path/f'{df_tmp.file}_{df_tmp.model}_mask', filetype) | ||
| if g_std is not None: | ||
| save_unc(g_std[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.model}_uncertainty', filetype) | ||
| if g_eng is not None: | ||
| save_unc(g_eng[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.model}_energy', filetype) | ||
| save_mask(pred, pred_path/f'{df_tmp.file}_model{df_tmp.model_no}_mask', filetype) | ||
| save_unc(std, unc_path/f'{df_tmp.file}_model{df_tmp.model_no}_uncertainty', filetype) | ||
| del self.inference_ensemble | ||
| if torch.cuda.is_available(): torch.cuda.empty_cache() | ||
| self.df_val = pd.DataFrame(res_list) | ||
@@ -551,3 +353,4 @@ if export_dir: | ||
| def show_valid_results(self, model_no=None, files=None, **kwargs): | ||
| def show_valid_results(self, model_no=None, files=None, metric_name='auto', **kwargs): | ||
| "Plot results of all or `file` validation images", | ||
| if self.df_val is None: self.get_valid_results(**kwargs) | ||
@@ -557,13 +360,14 @@ df = self.df_val | ||
| if model_no is not None: df = df[df.model_no==model_no] | ||
| if metric_name=='auto': metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score' | ||
| for _, r in df.iterrows(): | ||
| img = self.ds.get_data(r.image_path)[0][:] | ||
| msk = self.ds.get_data(r.image_path, mask=True)[0] | ||
| pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8') | ||
| std = zarr.load(r.uncertainty_path) | ||
| img = self.ds.data[r.file][:] | ||
| msk = self.ds.labels[r.file][:] | ||
| pred = self.g_pred[r.file][:] | ||
| std = self.g_std[r.file][:] | ||
| _d_model = f'Model {r.model_no}' | ||
| if self.tta: plot_results(img, msk, pred, std, df=r, model=_d_model) | ||
| else: plot_results(img, msk, pred, np.zeros_like(pred), df=r, model=_d_model) | ||
| plot_results(img, msk, pred, std, df=r, num_classes=self.num_classes, metric_name=metric_name, model=_d_model) | ||
| def load_ensemble(self, path=None): | ||
| path = path or self.ensemble_dir | ||
| def load_models(self, path=None): | ||
| "Get models saved at `path`" | ||
| path = path or self.ensemble_dir/'single_models' | ||
| models = sorted(get_files(path, extensions='.pth', recurse=False)) | ||
@@ -573,4 +377,4 @@ self.models = {} | ||
| for i, m in enumerate(models,1): | ||
| if i==0: self.n_classes = int(m.name.split('_')[2][0]) | ||
| else: assert self.n_classes==int(m.name.split('_')[2][0]), 'Check models. Models are trained on different number of classes.' | ||
| if i==0: self.num_classes = int(m.name.split('_')[2][0]) | ||
| else: assert self.num_classes==int(m.name.split('_')[2][0]), 'Check models. Models are trained on different number of classes.' | ||
| self.models[i] = m | ||
@@ -587,9 +391,53 @@ | ||
| def get_ensemble_results(self, files, zarr_store=None, export_dir=None, filetype='.png', **kwargs): | ||
| ep = EnsemblePredict(models_paths=self.models.values(), zarr_store=zarr_store) | ||
| g_smx, g_std, g_eng = ep.predict_images(files, bs=self.batch_size, ds_kwargs=self.pred_ds_kwargs, **kwargs) | ||
| chunk_store = g_smx.chunk_store.path | ||
| del ep | ||
| torch.cuda.empty_cache() | ||
| def lr_find(self, files=None, **kwargs): | ||
| "Wrapper function for learning rate finder" | ||
| files = files or self.files | ||
| dls = self._get_dls(files) | ||
| model = self._create_model() | ||
| learn = Learner(dls, model, metrics=self.metrics, wd=self.weight_decay, loss_func=self.loss_fn, opt_func=_optim_dict[self.optim]) | ||
| if self.mixed_precision_training: learn.to_fp16() | ||
| sug_lrs = learn.lr_find(**kwargs) | ||
| return sug_lrs, learn.recorder | ||
| # Cell | ||
| class EnsemblePredictor(EnsembleBase): | ||
| def __init__(self, *args, ensemble_path:Path=None, **kwargs): | ||
| if ensemble_path is not None: | ||
| self.load_inference_ensemble(ensemble_path) | ||
| super().__init__(*args, **kwargs) | ||
| if hasattr(self, 'inference_ensemble'): | ||
| self.config.num_classes = self.inference_ensemble.num_classes | ||
| if hasattr(self, 'files'): | ||
| self._create_ds(stats={}, use_zarr_data = False, verbose=1) | ||
| self.ensemble_dir = self.path/self.ens_dir | ||
| #if ensemble_path is not None: | ||
| # self.load_inference_ensemble(ensemble_path) | ||
| def load_inference_ensemble(self, ensemble_path:Path=None): | ||
| "Load inference_ensemble from `self.ensemle_dir` or from `path`" | ||
| path = ensemble_path or self.ensemble_dir | ||
| if path.is_dir(): | ||
| path_list = get_files(path, extensions='.pt', recurse=False) | ||
| if len(path_list)==0: | ||
| warnings.warn(f'No inference ensemble available at {path}. Did you train your ensemble correctly?') | ||
| return | ||
| path = path_list[0] | ||
| self.inference_ensemble_name = path.name | ||
| if hasattr(self, 'device'): self.inference_ensemble = torch.jit.load(path).to(self.device) | ||
| else: self.inference_ensemble = torch.jit.load(path) | ||
| print(f'Successfully loaded InferenceEnsemble from {path}') | ||
| def get_ensemble_results(self, file_list=None, export_dir=None, filetype='.png', **kwargs): | ||
| 'Predict files in file_list using InferenceEnsemble' | ||
| if file_list is not None: | ||
| self.files = file_list | ||
| self._create_ds(stats={}, use_zarr_data = False, verbose=1) | ||
| if export_dir: | ||
@@ -603,68 +451,78 @@ export_dir = Path(export_dir) | ||
| res_list = [] | ||
| for f in files: | ||
| pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8') | ||
| for f in progress_bar(self.files): | ||
| img = self.ds.read_img(f) | ||
| pred, smx, std = self.predict(img) | ||
| self.save_preds_zarr(f.name, pred, smx, std) | ||
| df_tmp = pd.Series({'file' : f.name, | ||
| 'ensemble' : self.model_name, | ||
| 'n_models' : len(self.models), | ||
| #'mean_energy': np.mean(g_eng[f.name][:][pred>0]), | ||
| 'uncertainty_score': np.mean(g_std[f.name][:][pred>0]) if g_std is not None else None, | ||
| 'ensemble' : self.inference_ensemble_name, | ||
| 'uncertainty_score': np.mean(std[pred>0]), | ||
| 'image_path': f, | ||
| 'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}', | ||
| 'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None, | ||
| 'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None}) | ||
| 'pred_path': f'{self.store}/{self.g_pred.path}/{f.name}', | ||
| 'softmax_path': f'{self.store}/{self.g_smx.path}/{f.name}', | ||
| 'uncertainty_path': f'{self.store}/{self.g_std.path}/{f.name}'}) | ||
| res_list.append(df_tmp) | ||
| if export_dir: | ||
| save_mask(pred, pred_path/f'{df_tmp.file}_{df_tmp.ensemble}_mask', filetype) | ||
| if g_std is not None: | ||
| save_unc(g_std[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.ensemble}_unc', filetype) | ||
| if g_eng is not None: | ||
| save_unc(g_eng[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.ensemble}_energy', filetype) | ||
| save_mask(pred, pred_path/f'{df_tmp.file}_mask', filetype) | ||
| save_unc(std, unc_path/f'{df_tmp.file}_unc', filetype) | ||
| self.df_ens = pd.DataFrame(res_list) | ||
| return g_smx, g_std, g_eng | ||
| return self.g_pred, self.g_smx, self.g_std | ||
| def score_ensemble_results(self, mask_dir=None, label_fn=None): | ||
| if mask_dir is not None and label_fn is None: | ||
| label_fn = get_label_fn(self.df_ens.image_path[0], self.path/mask_dir) | ||
| for i, r in self.df_ens.iterrows(): | ||
| if label_fn is not None: | ||
| msk_path = self.label_fn(r.image_path) | ||
| msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels) | ||
| self.df_ens.loc[i, 'mask_path'] = msk_path | ||
| "Compare ensemble results to given segmentation masks." | ||
| if any(v is not None for v in (mask_dir, label_fn)): | ||
| self.label_fn = label_fn or self.get_label_fn(mask_dir) | ||
| self._create_ds(stats={}, use_zarr_data = False, verbose=1) | ||
| print('Calculating metrics') | ||
| for i, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)): | ||
| msk = self.ds.labels[r.file][:] | ||
| pred = self.g_pred[r.file][:] | ||
| if self.num_classes==2: | ||
| self.df_ens.loc[i, f'dice_score'] = binary_dice_score(msk, pred) | ||
| else: | ||
| msk = self.ds.labels[r.file][:] | ||
| pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8') | ||
| self.df_ens.loc[i, 'dice_score'] = dice_score(msk, pred) | ||
| for cl in range(self.num_classes): | ||
| msk_bin = msk==cl | ||
| pred_bin = pred==cl | ||
| if np.any([msk_bin, pred_bin]): | ||
| self.df_ens.loc[i, f'dice_score_class{cl}'] = binary_dice_score(msk_bin, pred_bin) | ||
| if self.num_classes>2: | ||
| self.df_ens['average_dice_score'] = self.df_ens[[col for col in self.df_ens if col.startswith('dice_score_class')]].mean(axis=1) | ||
| return self.df_ens | ||
| def show_ensemble_results(self, files=None, unc=True, unc_metric=None, metric_name='dice_score'): | ||
| def show_ensemble_results(self, files=None, unc=True, unc_metric=None, metric_name='auto'): | ||
| "Show result of ensemble or `model_no`" | ||
| assert self.df_ens is not None, "Please run `get_ensemble_results` first." | ||
| df = self.df_ens | ||
| if files is not None: df = df.reset_index().set_index('file', drop=False).loc[files] | ||
| if metric_name=='auto': metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score' | ||
| for _, r in df.iterrows(): | ||
| imgs = [] | ||
| imgs.append(_read_img(r.image_path)[:]) | ||
| imgs.append(self.ds.read_img(r.image_path)) | ||
| if metric_name in r.index: | ||
| try: msk = self.ds.labels[r.file][:] | ||
| except: msk = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels) | ||
| imgs.append(msk) | ||
| imgs.append(self.ds.labels[r.file][:]) | ||
| hastarget=True | ||
| else: | ||
| hastarget=False | ||
| imgs.append(np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')) | ||
| if unc: imgs.append(zarr.load(r.uncertainty_path)) | ||
| plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric) | ||
| imgs.append(self.g_pred[r.file]) | ||
| if unc: imgs.append(self.g_std[r.file]) | ||
| plot_results(*imgs, df=r, hastarget=hastarget, num_classes=self.num_classes, metric_name=metric_name, unc_metric=unc_metric) | ||
| def get_cellpose_results(self, export_dir=None): | ||
| 'Get instance segmentation results using the cellpose integration' | ||
| assert self.df_ens is not None, "Please run `get_ensemble_results` first." | ||
| cl = self.cellpose_export_class | ||
| assert cl<self.n_classes, f'{cl} not avaialable from {self.n_classes} classes' | ||
| assert cl<self.num_classes, f'{cl} not avaialable from {self.num_classes} classes' | ||
| smxs, preds = [], [] | ||
| for _, r in self.df_ens.iterrows(): | ||
| softmax = zarr.load(r.softmax_path) | ||
| smxs.append(softmax) | ||
| preds.append(np.argmax(softmax, axis=-1).astype('uint8')) | ||
| smxs.append(self.g_smx[r.file][:]) | ||
| preds.append(self.g_pred[r.file][:]) | ||
| probs = [x[...,cl] for x in smxs] | ||
| probs = [x[cl] for x in smxs] | ||
| masks = [x==cl for x in preds] | ||
@@ -678,7 +536,6 @@ cp_masks = run_cellpose(probs, masks, | ||
| if export_dir: | ||
| export_dir = Path(export_dir) | ||
| cp_path = export_dir/'cellpose_masks' | ||
| cp_path.mkdir(parents=True, exist_ok=True) | ||
| export_dir = Path(export_dir)/'instance_labels' | ||
| export_dir.mkdir(parents=True, exist_ok=True) | ||
| for idx, r in self.df_ens.iterrows(): | ||
| tifffile.imwrite(cp_path/f'{r.file}_class{cl}.tif', cp_masks[idx], compress=6) | ||
| tifffile.imwrite(export_dir/f'{r.file}_class{cl}.tif', cp_masks[idx], compress=6) | ||
@@ -689,32 +546,31 @@ self.cellpose_masks = cp_masks | ||
| def score_cellpose_results(self, mask_dir=None, label_fn=None): | ||
| "Compare cellpose nstance segmentation results to given masks." | ||
| assert self.cellpose_masks is not None, 'Run get_cellpose_results() first' | ||
| if mask_dir is not None and label_fn is None: | ||
| label_fn = get_label_fn(self.df_ens.image_path[0], self.path/mask_dir) | ||
| if any(v is not None for v in (mask_dir, label_fn)): | ||
| self.label_fn = label_fn or self.get_label_fn(mask_dir) | ||
| self._create_ds(stats={}, use_zarr_data = False, verbose=1) | ||
| cl = self.cellpose_export_class | ||
| for i, r in self.df_ens.iterrows(): | ||
| if label_fn is not None: | ||
| msk_path = self.label_fn(r.image_path) | ||
| msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels) | ||
| self.df_ens.loc[i, 'mask_path'] = msk_path | ||
| else: | ||
| msk = self.ds.labels[r.file][:] | ||
| _, msk = cv2.connectedComponents(msk, connectivity=4) | ||
| msk = self.ds.labels[r.file][:]==cl | ||
| _, msk = cv2.connectedComponents(msk.astype('uint8'), connectivity=4) | ||
| pred = self.cellpose_masks[i] | ||
| ap, tp, fp, fn = get_instance_segmentation_metrics(msk, pred, is_binary=False, min_pixel=self.min_pixel_export) | ||
| self.df_ens.loc[i, 'mean_average_precision'] = ap.mean() | ||
| self.df_ens.loc[i, 'average_precision_at_iou_50'] = ap[0] | ||
| self.df_ens.loc[i, f'mAP_class{cl}'] = ap.mean() | ||
| self.df_ens.loc[i, f'mAP_iou50_class{cl}'] = ap[0] | ||
| return self.df_ens | ||
| def show_cellpose_results(self, files=None, unc=True, unc_metric=None, metric_name='mean_average_precision'): | ||
| def show_cellpose_results(self, files=None, unc_metric=None, metric_name='auto'): | ||
| 'Show instance segmentation results from cellpose predictions.' | ||
| assert self.df_ens is not None, "Please run `get_ensemble_results` first." | ||
| df = self.df_ens.reset_index() | ||
| if files is not None: df = df.set_index('file', drop=False).loc[files] | ||
| if metric_name=='auto': metric_name=f'mAP_class{self.cellpose_export_class}' | ||
| for _, r in df.iterrows(): | ||
| imgs = [] | ||
| imgs.append(_read_img(r.image_path)[:]) | ||
| imgs = [self.ds.read_img(r.image_path)] | ||
| if metric_name in r.index: | ||
| try: | ||
| mask = self.ds.labels[idx][:] | ||
| except: | ||
| mask = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels) | ||
| _, comps = cv2.connectedComponents((mask==self.cellpose_export_class).astype('uint8'), connectivity=4) | ||
| mask = self.ds.labels[r.file][:] | ||
| mask = (mask==self.cellpose_export_class).astype('uint8') | ||
| _, comps = cv2.connectedComponents(mask, connectivity=4) | ||
| imgs.append(label2rgb(comps, bg_label=0)) | ||
@@ -724,16 +580,9 @@ hastarget=True | ||
| hastarget=False | ||
| imgs.append(label2rgb(self.cellpose_masks[r['index']], bg_label=0)) | ||
| if unc: imgs.append(zarr.load(r.uncertainty_path)) | ||
| plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric) | ||
| imgs.append(self.g_std[r.file]) | ||
| plot_results(*imgs, df=r, hastarget=hastarget, num_classes=self.num_classes, instance_labels=True, metric_name=metric_name, unc_metric=unc_metric) | ||
| def lr_find(self, files=None, **kwargs): | ||
| files = files or self.files | ||
| dls = self._get_dls(files) | ||
| model = self._create_model() | ||
| learn = Learner(dls, model, metrics=self.metrics, wd=self.weight_decay, loss_func=self.loss_fn, opt_func=_optim_dict[self.optim]) | ||
| if self.mixed_precision_training: learn.to_fp16() | ||
| sug_lrs = learn.lr_find(**kwargs) | ||
| return sug_lrs, learn.recorder | ||
| def export_imagej_rois(self, output_folder='ROI_sets', **kwargs): | ||
| 'Export ImageJ ROI Sets to `ouput_folder`' | ||
| assert self.df_ens is not None, "Please run prediction first." | ||
@@ -744,40 +593,13 @@ | ||
| for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)): | ||
| mask = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8') | ||
| uncertainty = zarr.load(r.uncertainty_path) | ||
| export_roi_set(mask, uncertainty, name=r.file, path=output_folder, ascending=False, **kwargs) | ||
| pred = self.g_pred[r.file][:] | ||
| uncertainty = self.g_std[r.file][:] | ||
| export_roi_set(pred, uncertainty, name=r.file, path=output_folder, ascending=False, **kwargs) | ||
| def export_cellpose_rois(self, output_folder='cellpose_ROI_sets', **kwargs): | ||
| 'Export cellpose predictions to ImageJ ROI Sets in `ouput_folder`' | ||
| output_folder = Path(output_folder) | ||
| output_folder.mkdir(exist_ok=True, parents=True) | ||
| for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)): | ||
| mask = self.cellpose_masks[idx] | ||
| uncertainty = zarr.load(r.uncertainty_path) | ||
| export_roi_set(mask, uncertainty, instance_labels=True, name=r.file, path=output_folder, ascending=False, **kwargs) | ||
| #def clear_tmp(self): | ||
| # try: | ||
| # shutil.rmtree('/tmp/*', ignore_errors=True) | ||
| # shutil.rmtree(self.path/'.tmp') | ||
| # print(f'Deleted temporary files from {self.path/".tmp"}') | ||
| # except: print(f'No temporary files to delete at {self.path/".tmp"}') | ||
| # Cell | ||
| add_docs(EnsembleLearner, "Meta class to train and predict model ensembles with `n` models", | ||
| fit="Fit model number `i`", | ||
| fit_ensemble="Fit `i` models and `skip` existing", | ||
| get_valid_results="Validate models on validation data and save results", | ||
| show_valid_results="Plot results of all or `file` validation images", | ||
| get_ensemble_results="Get models and ensemble results", | ||
| score_ensemble_results="Compare ensemble results to given segmentation masks.", | ||
| show_ensemble_results="Show result of ensemble or `model_no`", | ||
| load_ensemble="Get models saved at `path`", | ||
| get_cellpose_results='Get instance segmentation results using the cellpose integration', | ||
| score_cellpose_results="Compare cellpose nstance segmentation results to given masks.", | ||
| show_cellpose_results='Show instance segmentation results from cellpose predictions.', | ||
| get_loss="Get loss function from loss name (config)", | ||
| set_n="Change to `n` models per ensemble", | ||
| lr_find="Wrapper for learning rate finder", | ||
| export_imagej_rois='Export ImageJ ROI Sets to `ouput_folder`', | ||
| export_cellpose_rois='Export cellpose predictions to ImageJ ROI Sets in `ouput_folder`', | ||
| #clear_tmp="Clear directory with temporary files" | ||
| ) | ||
| pred = self.cellpose_masks[idx] | ||
| uncertainty = self.g_std[r.file][:] | ||
| export_roi_set(pred, uncertainty, instance_labels=True, name=r.file, path=output_folder, ascending=False, **kwargs) |
+42
-2
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_losses.ipynb (unless otherwise specified). | ||
| __all__ = ['LOSSES', 'FastaiLoss', 'WeightedLoss', 'JointLoss', 'get_loss'] | ||
| __all__ = ['LOSSES', 'FastaiLoss', 'WeightedLoss', 'JointLoss', 'Poly1CrossEntropyLoss', 'get_loss'] | ||
@@ -16,3 +16,3 @@ # Cell | ||
| # Cell | ||
| LOSSES = ['CrossEntropyLoss', 'DiceLoss', 'SoftCrossEntropyLoss', 'CrossEntropyDiceLoss', 'JaccardLoss', 'FocalLoss', 'LovaszLoss', 'TverskyLoss'] | ||
| LOSSES = ['CrossEntropyLoss', 'DiceLoss', 'SoftCrossEntropyLoss', 'CrossEntropyDiceLoss', 'JaccardLoss', 'FocalLoss', 'LovaszLoss', 'TverskyLoss', 'Poly1CrossEntropyLoss'] | ||
@@ -63,2 +63,39 @@ # Cell | ||
| # Cell | ||
| class Poly1CrossEntropyLoss(nn.Module): | ||
| def __init__(self, | ||
| num_classes: int, | ||
| epsilon: float = 1.0, | ||
| reduction: str = "mean"): | ||
| """ | ||
| Create instance of Poly1CrossEntropyLoss | ||
| :param num_classes: | ||
| :param epsilon: | ||
| :param reduction: one of none|sum|mean, apply reduction to final loss tensor | ||
| """ | ||
| super(Poly1CrossEntropyLoss, self).__init__() | ||
| self.num_classes = num_classes | ||
| self.epsilon = epsilon | ||
| self.reduction = reduction | ||
| return | ||
| def forward(self, logits, labels): | ||
| """ | ||
| Forward pass | ||
| :param logits: tensor of shape [BNHW] | ||
| :param labels: tensor of shape [BHW] | ||
| :return: poly cross-entropy loss | ||
| """ | ||
| labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(device=logits.device, | ||
| dtype=logits.dtype) | ||
| labels_onehot = torch.moveaxis(labels_onehot, -1, 1) | ||
| pt = torch.sum(labels_onehot * F.softmax(logits, dim=1), dim=1) | ||
| CE = F.cross_entropy(input=logits, target=labels, reduction='none') | ||
| poly1 = CE + self.epsilon * (1 - pt) | ||
| if self.reduction == "mean": | ||
| poly1 = poly1.mean() | ||
| elif self.reduction == "sum": | ||
| poly1 = poly1.sum() | ||
| return poly1 | ||
| # Cell | ||
| def get_loss(loss_name, mode='multiclass', classes=[1], smooth_factor=0., alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', **kwargs): | ||
@@ -96,2 +133,5 @@ 'Load losses from based on loss_name' | ||
| elif loss_name=="Poly1CrossEntropyLoss": | ||
| loss = Poly1CrossEntropyLoss(num_classes=max(classes)+1) | ||
| return loss |
+36
-3
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_models.ipynb (unless otherwise specified). | ||
| __all__ = ['ARCHITECTURES', 'ENCODERS', 'get_pretrained_options', 'create_smp_model', 'save_smp_model', | ||
| 'load_smp_model', 'check_cellpose_installation', 'get_diameters', 'run_cellpose'] | ||
| __all__ = ['ARCHITECTURES', 'ENCODERS', 'get_pretrained_options', 'PATCH_UNET_DECODER', 'create_smp_model', | ||
| 'save_smp_model', 'load_smp_model', 'check_cellpose_installation', 'get_diameters', 'run_cellpose'] | ||
@@ -31,2 +31,24 @@ # Cell | ||
| # Cell | ||
| PATCH_UNET_DECODER = False | ||
| @patch | ||
| def forward(self:smp.decoders.unet.decoder.UnetDecoder, *features): | ||
| features = features[1:] # remove first skip with same spatial resolution | ||
| features = features[::-1] # reverse channels to start from head of encoder | ||
| head = features[0] | ||
| skips = features[1:] | ||
| x = self.center(head) | ||
| for i, decoder_block in enumerate(self.blocks): | ||
| skip = skips[i] if i < len(skips) else None | ||
| x = decoder_block(x, skip) | ||
| if PATCH_UNET_DECODER: | ||
| x = torch.nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) | ||
| return x | ||
| # Cell | ||
| def create_smp_model(arch, **kwargs): | ||
@@ -36,4 +58,15 @@ 'Create segmentation_models_pytorch model' | ||
| assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}' | ||
| is_convnext_encoder = kwargs['encoder_name'].startswith('tu-convnext') | ||
| assert not all((arch!="Unet", is_convnext_encoder)), 'ConvNeXt encoder can only be used with Unet' | ||
| if arch=="Unet": model = smp.Unet(**kwargs) | ||
| if arch=="Unet": | ||
| global PATCH_UNET_DECODER | ||
| if is_convnext_encoder: | ||
| kwargs['encoder_depth'] = 4 | ||
| kwargs['decoder_channels'] = (256, 128, 64, 16) | ||
| PATCH_UNET_DECODER = True | ||
| else: | ||
| PATCH_UNET_DECODER = False | ||
| model = smp.Unet(**kwargs) | ||
| elif arch=="UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs) | ||
@@ -40,0 +73,0 @@ elif arch=="MAnet":model = smp.MAnet(**kwargs) |
+60
-103
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07_tta.ipynb (unless otherwise specified). | ||
| __all__ = ['rot90', 'hflip', 'vflip', 'BaseTransform', 'Chain', 'Transformer', 'Compose', 'Merger', 'HorizontalFlip', | ||
| 'VerticalFlip', 'Rotate90'] | ||
| __all__ = ['rot90', 'hflip', 'vflip', 'BaseTransform', 'HorizontalFlip', 'VerticalFlip', 'Rotate90', 'Chain', | ||
| 'Transformer', 'Compose'] | ||
@@ -9,3 +9,2 @@ # Cell | ||
| import itertools | ||
| from functools import partial | ||
| from typing import List, Optional, Union | ||
@@ -15,11 +14,14 @@ from fastcore.foundation import store_attr | ||
| # Cell | ||
| def rot90(x, k=1): | ||
| @torch.jit.script | ||
| def rot90(x:torch.Tensor, k:int=1): | ||
| "rotate batch of images by 90 degrees k times" | ||
| return torch.rot90(x, k, (2, 3)) | ||
| def hflip(x): | ||
| @torch.jit.script | ||
| def hflip(x:torch.Tensor): | ||
| "flip batch of images horizontally" | ||
| return x.flip(3) | ||
| def vflip(x): | ||
| @torch.jit.script | ||
| def vflip(x:torch.Tensor): | ||
| "flip batch of images vertically" | ||
@@ -29,83 +31,9 @@ return x.flip(2) | ||
| # Cell | ||
| class BaseTransform: | ||
| class BaseTransform(torch.nn.Module): | ||
| identity_param = None | ||
| def __init__(self, pname: str, params: Union[list, tuple]): store_attr() | ||
| class Chain: | ||
| def __init__(self, functions: List[callable]): | ||
| self.functions = functions or [] | ||
| def __call__(self, x): | ||
| for f in self.functions: | ||
| x = f(x) | ||
| return x | ||
| class Transformer: | ||
| def __init__(self, image_pipeline: Chain, mask_pipeline: Chain): | ||
| def __init__(self, pname: str, params: Union[list, tuple]): | ||
| super(BaseTransform, self).__init__() | ||
| store_attr() | ||
| def augment_image(self, image): | ||
| return self.image_pipeline(image) | ||
| def deaugment_mask(self, mask): | ||
| return self.mask_pipeline(mask) | ||
| class Compose: | ||
| def __init__(self, aug_transforms: List[BaseTransform]): | ||
| store_attr() | ||
| self.aug_transform_parameters = list(itertools.product(*[t.params for t in self.aug_transforms])) | ||
| self.deaug_transforms = aug_transforms[::-1] | ||
| self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters] | ||
| def __iter__(self) -> Transformer: | ||
| for aug_params, deaug_params in zip(self.aug_transform_parameters, self.deaug_transform_parameters): | ||
| image_aug_chain = Chain([partial(t.apply_aug_image, **{t.pname: p}) | ||
| for t, p in zip(self.aug_transforms, aug_params)]) | ||
| mask_deaug_chain = Chain([partial(t.apply_deaug_mask, **{t.pname: p}) | ||
| for t, p in zip(self.deaug_transforms, deaug_params)]) | ||
| yield Transformer(image_pipeline=image_aug_chain, mask_pipeline=mask_deaug_chain) | ||
| def __len__(self) -> int: | ||
| return len(self.aug_transform_parameters) | ||
| # Cell | ||
| import matplotlib.pyplot as plt | ||
| class Merger: | ||
| def __init__(self): | ||
| self.output = [] | ||
| def append(self, x): | ||
| self.output.append(torch.as_tensor(x)) | ||
| def result(self, type='mean'): | ||
| s = torch.stack(self.output) | ||
| if type == 'max': | ||
| result = torch.max(s, dim=0)[0] | ||
| elif type == 'mean': | ||
| result = torch.mean(s, dim=0) | ||
| elif type == 'std': | ||
| result = torch.std(s, dim=0) | ||
| elif type == 'uncertainty': | ||
| # adapted from https://github.com/ykwon0407/UQ_BNN/blob/master/retina/utils.py | ||
| aleatoric_uncertainty = torch.mean(s * (1 - s), dim=0) | ||
| epistemic_uncertainty = torch.mean(s**2, dim=0) - torch.mean(s, dim=0)**2 | ||
| result = epistemic_uncertainty + aleatoric_uncertainty | ||
| elif type == 'aleatoric_uncertainty': | ||
| result = torch.mean(s * (1 - s), dim=0) | ||
| elif type == 'epistemic_uncertainty': | ||
| result = torch.mean(s**2, dim=0) - torch.mean(s, dim=0)**2 | ||
| elif type == 'entropy': | ||
| result = -torch.sum(s * torch.log(s), dim=0) | ||
| else: | ||
| raise ValueError('Not correct merge type `{}`.'.format(self.type)) | ||
| return result | ||
| # Cell | ||
| class HorizontalFlip(BaseTransform): | ||
@@ -115,12 +43,8 @@ "Flip images horizontally (left->right)" | ||
| def __init__(self): | ||
| super().__init__("apply", [False, True]) | ||
| super(HorizontalFlip, self).__init__("apply", [0, 1]) | ||
| def apply_aug_image(self, image, apply=False, **kwargs): | ||
| if apply: image = hflip(image) | ||
| return image | ||
| def forward(self, x:torch.Tensor, apply:int=0, deaug:bool=False): | ||
| if apply==1: x = hflip(x) | ||
| return x | ||
| def apply_deaug_mask(self, mask, apply=False, **kwargs): | ||
| if apply: mask = hflip(mask) | ||
| return mask | ||
| # Cell | ||
@@ -131,12 +55,9 @@ class VerticalFlip(BaseTransform): | ||
| def __init__(self): | ||
| super().__init__("apply", [False, True]) | ||
| super().__init__("apply", [0, 1]) | ||
| def apply_aug_image(self, image, apply=False, **kwargs): | ||
| if apply: image = vflip(image) | ||
| return image | ||
| def forward(self, x:torch.Tensor, apply:int=0, deaug:bool=False): | ||
| if apply==1: | ||
| x = vflip(x) | ||
| return x | ||
| def apply_deaug_mask(self, mask, apply=False, **kwargs): | ||
| if apply: mask = vflip(mask) | ||
| return mask | ||
| # Cell | ||
@@ -147,11 +68,47 @@ class Rotate90(BaseTransform): | ||
| def __init__(self, angles: List[int]): | ||
| super().__init__("angle", angles) | ||
| if self.identity_param not in angles: | ||
| angles = [self.identity_param] + list(angles) | ||
| super().__init__("angle", angles) | ||
| def apply_aug_image(self, image, angle=0, **kwargs): | ||
| @torch.jit.export | ||
| def apply_aug_image(self, image:torch.Tensor, angle:int=0): #, **kwargs | ||
| k = angle // 90 if angle >= 0 else (angle + 360) // 90 | ||
| #k = torch.div(angle, 90, rounding_mode='trunc') if angle >= 0 else torch.div((angle + 360), 90, rounding_mode='trunc') | ||
| return rot90(image, k) | ||
| def apply_deaug_mask(self, mask, angle=0, **kwargs): | ||
| return self.apply_aug_image(mask, -angle) | ||
| def forward(self, x:torch.Tensor, angle:int=0, deaug:bool=False): | ||
| return self.apply_aug_image(x, angle=-angle if deaug else angle) | ||
| # Cell | ||
| class Chain(torch.nn.Module): | ||
| def __init__(self, transforms: List[BaseTransform]): | ||
| super().__init__() | ||
| self.transforms = torch.nn.ModuleList(transforms) | ||
| def forward(self, x, args:List[int], deaug:bool=False): | ||
| for i, t in enumerate(self.transforms): | ||
| x = t(x, args[i], deaug) | ||
| return x | ||
| # Cell | ||
| class Transformer(torch.nn.Module): | ||
| def __init__(self, transforms: List[BaseTransform], args:List[int]): | ||
| super(Transformer, self).__init__() | ||
| self.aug_pipeline = Chain(transforms) | ||
| self.deaug_pipeline = Chain(transforms[::-1]) | ||
| self.args = args | ||
| @torch.jit.export | ||
| def augment(self, image:torch.Tensor): | ||
| return self.aug_pipeline(image, self.args, deaug=False) | ||
| @torch.jit.export | ||
| def deaugment(self, mask:torch.Tensor): | ||
| return self.deaug_pipeline(mask, self.args[::-1], deaug=True) | ||
| # Cell | ||
| class Compose(torch.nn.Module): | ||
| def __init__(self, aug_transforms: List[BaseTransform]): | ||
| super(Compose, self).__init__() | ||
| self.transform_parameters = list(itertools.product(*[t.params for t in aug_transforms])) | ||
| self.items = torch.nn.ModuleList([Transformer(aug_transforms, args) for args in self.transform_parameters]) |
+99
-57
| # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/06_utils.ipynb (unless otherwise specified). | ||
| __all__ = ['unzip', 'download_sample_data', 'install_package', 'import_package', 'compose_albumentations', | ||
| 'ensemble_results', 'plot_results', 'iou', 'dice_score', 'label_mask', 'get_instance_segmentation_metrics', | ||
| 'export_roi_set', 'calc_iterations', 'get_label_fn', 'save_mask', 'save_unc'] | ||
| __all__ = ['unzip', 'download_sample_data', 'install_package', 'import_package', 'compose_albumentations', 'clean_show', | ||
| 'plot_results', 'multiclass_dice_score', 'binary_dice_score', 'dice_score', 'label_mask', | ||
| 'get_instance_segmentation_metrics', 'export_roi_set', 'calc_iterations', 'get_label_fn', 'save_mask', | ||
| 'save_unc'] | ||
@@ -11,2 +12,3 @@ # Cell | ||
| from pathlib import Path | ||
| from scipy import ndimage | ||
@@ -19,5 +21,12 @@ from scipy.spatial.distance import jaccard | ||
| from scipy.optimize import linear_sum_assignment | ||
| from sklearn.metrics import jaccard_score | ||
| from sklearn.metrics import jaccard_score, multilabel_confusion_matrix | ||
| from sklearn.metrics._classification import _prf_divide | ||
| import matplotlib as mpl | ||
| import matplotlib.pyplot as plt | ||
| from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
| import albumentations as A | ||
| from fastcore.foundation import patch | ||
@@ -85,16 +94,41 @@ from fastcore.meta import delegates | ||
| # Cell | ||
| def ensemble_results(res_dict, file, std=False): | ||
| "Combines single model predictions." | ||
| idx = 2 if std else 0 | ||
| a = [np.array(res_dict[(mod, f)][idx]) for mod, f in res_dict if f==file] | ||
| a = np.mean(a, axis=0) | ||
| if std: | ||
| a = a[...,0] | ||
| def clean_show(ax, msk, title, cmap, cbar=None, ticks=None, **kwargs): | ||
| img = ax.imshow(msk, cmap=cmap, **kwargs) | ||
| #if cbar is not None: | ||
| divider = make_axes_locatable(ax) | ||
| cax = divider.append_axes("right", size="5%", pad=0.05) | ||
| if cbar=='experts': | ||
| scale = ticks/(ticks+1) | ||
| cbr = plt.colorbar(img, cax=cax, ticks=[i*(scale)+(scale/2) for i in range(ticks+1)]) | ||
| cbr.set_ticklabels([i for i in range(ticks+1)]) | ||
| cbr.set_label('# of experts', rotation=270, labelpad=+15, fontsize="larger") | ||
| elif cbar=='classes': | ||
| scale = ticks/(ticks) | ||
| bounds = [i for i in range(ticks+1)] | ||
| cmap = plt.cm.get_cmap(cmap) | ||
| norm = mpl.colors.BoundaryNorm(bounds, cmap.N) | ||
| cbr = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), | ||
| cax=cax, ticks=[i*(scale)+(scale/2) for i in range(ticks)]) | ||
| cbr.set_ticklabels([i for i in range(ticks)]) | ||
| cbr.set_label('Classes', rotation=270, labelpad=+15, fontsize="larger") | ||
| elif cbar=='uncertainty': | ||
| cbr = plt.colorbar(img, cax=cax) | ||
| #cbr.set_ticklabels([i for i in range(ticks)]) | ||
| #cbr.set_label('Uncertainty', rotation=270, labelpad=+15, fontsize="larger") | ||
| else: | ||
| a = np.argmax(a, axis=-1) | ||
| return a | ||
| cbr = plt.colorbar(img, cax=cax) | ||
| cax.set_axis_off() | ||
| cbr.remove() | ||
| ax.set_axis_off() | ||
| ax.set_title(title) | ||
| # Cell | ||
| def plot_results(*args, df, hastarget=False, model=None, metric_name='dice_score', unc_metric=None, figsize=(20, 20), **kwargs): | ||
| def plot_results(*args, df, hastarget=False, num_classes=2, model=None, instance_labels=False, | ||
| metric_name='dice_score', unc_metric=None, figsize=(20, 20), msk_cmap='viridis', **kwargs): | ||
| "Plot images, (masks), predictions and uncertainties side-by-side." | ||
| vkwargs = {'vmin':0, 'vmax':num_classes-1} if not instance_labels else {} | ||
| unc_vkwargs = {'vmin':0, 'vmax':1} | ||
| class_kwargs = {'cbar':'classes', 'ticks':num_classes} if not instance_labels else {} | ||
| if len(args)==4: | ||
@@ -113,35 +147,20 @@ img, msk, pred, pred_std = args | ||
| img=img[...,0] | ||
| axs[0].imshow(img) | ||
| axs[0].set_axis_off() | ||
| axs[0].set_title(f'File {df.file}') | ||
| clean_show(axs[0], img, f'File {df.file}', None) | ||
| unc_title = f'Uncertainty \n {unc_metric}: {df[unc_metric]:.3f}' if unc_metric else 'Uncertainty' | ||
| pred_title = 'Prediction' if model is None else f'Prediction {model}' | ||
| if len(args)==4: | ||
| axs[1].imshow(msk) | ||
| axs[1].set_axis_off() | ||
| axs[1].set_title('Target') | ||
| axs[2].imshow(pred) | ||
| axs[2].set_axis_off() | ||
| axs[2].set_title(f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}') | ||
| axs[3].imshow(pred_std) | ||
| axs[3].set_axis_off() | ||
| axs[3].set_title(unc_title) | ||
| clean_show(axs[1], msk, 'Target', msk_cmap, cbar='classes', ticks=num_classes, **vkwargs, **kwargs) | ||
| pred_title = f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}' | ||
| clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs) | ||
| clean_show(axs[3], pred_std, unc_title, 'hot', cbar='uncertainty', **unc_vkwargs, **kwargs) | ||
| elif len(args)==3 and not hastarget: | ||
| axs[1].imshow(pred) | ||
| axs[1].set_axis_off() | ||
| axs[1].set_title(pred_title) | ||
| axs[2].imshow(pred_std) | ||
| axs[2].set_axis_off() | ||
| axs[2].set_title(unc_title) | ||
| clean_show(axs[1], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs) | ||
| clean_show(axs[2], pred_std, unc_title, 'hot', cbar='uncertainty', **unc_vkwargs, **kwargs) | ||
| elif len(args)==3: | ||
| axs[1].imshow(msk) | ||
| axs[1].set_axis_off() | ||
| axs[1].set_title('Target') | ||
| axs[2].imshow(pred) | ||
| axs[2].set_axis_off() | ||
| axs[2].set_title(f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}') | ||
| clean_show(axs[1], msk, 'Target', msk_cmap, cbar='classes', ticks=num_classes,**vkwargs, **kwargs) | ||
| pred_title = f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}' | ||
| clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs) | ||
| elif len(args)==2: | ||
| axs[1].imshow(pred) | ||
| axs[1].set_axis_off() | ||
| axs[1].set_title(pred_title) | ||
| clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs) | ||
| plt.show() | ||
@@ -172,21 +191,44 @@ | ||
| # Cell | ||
| def iou(a,b,threshold=0.5, average='macro', **kwargs): | ||
| '''Computes the Intersection-Over-Union metric.''' | ||
| a = np.array(a).flatten() | ||
| b = np.array(b).flatten() | ||
| if a.max()>1 or b.max()>1: | ||
| return jaccard_score(a, b, average=average, **kwargs) | ||
| else: | ||
| a = np.array(a) > threshold | ||
| b = np.array(b) > threshold | ||
| overlap = a*b # Logical AND | ||
| union = a+b # Logical OR | ||
| return np.divide(np.count_nonzero(overlap),np.count_nonzero(union)) | ||
| def multiclass_dice_score(y_true, y_pred, average='macro', **kwargs): | ||
| '''Computes the Sørensen–Dice coefficient for multiclass segmentations.''' | ||
| # Cell | ||
| def dice_score(*args, **kwargs): | ||
| '''Computes the Dice coefficient metric.''' | ||
| iou_score = iou(*args, **kwargs) | ||
| average_options = (None, 'micro', 'macro') | ||
| if average not in average_options: | ||
| raise ValueError('average has to be one of ' + str(average_options)) | ||
| MCM = multilabel_confusion_matrix(y_true, y_pred, **kwargs) | ||
| numerator = 2*MCM[:, 1, 1] | ||
| denominator = 2*MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0] | ||
| if average == 'micro': | ||
| numerator = np.array([numerator.sum()]) | ||
| denominator = np.array([denominator.sum()]) | ||
| dice_scores = _prf_divide(numerator, denominator, '', None, None, '') | ||
| if average is None: | ||
| return dice_scores | ||
| return np.average(dice_scores) | ||
| def binary_dice_score(y_true, y_pred): | ||
| '''Compute the Sørensen–Dice coefficient for binary segmentations.''' | ||
| overlap = y_true*y_pred # Logical AND | ||
| union = y_true+y_pred # Logical OR | ||
| iou_score = np.divide(np.count_nonzero(overlap),np.count_nonzero(union)) # | ||
| return 2*iou_score/(iou_score+1) | ||
| def dice_score(y_true, y_pred, average='macro', num_classes=2, **kwargs): | ||
| '''Computes the Sørensen–Dice coefficient.''' | ||
| y_true = np.array(y_true).flatten() | ||
| y_pred = np.array(y_pred).flatten() | ||
| if y_true.max()>1 or y_pred.max()>1 or num_classes>2: | ||
| labels = [i for i in range(num_classes)] | ||
| return multiclass_dice_score(y_true, y_pred, average=average, labels=labels, **kwargs) | ||
| else: | ||
| return binary_dice_score(y_true, y_pred) | ||
| # Cell | ||
@@ -193,0 +235,0 @@ def label_mask(mask, threshold=0.5, connectivity=4, min_pixel=0, do_watershed=False, exclude_border=False): |
+2
-7
| Metadata-Version: 2.1 | ||
| Name: deepflash2 | ||
| Version: 0.1.8 | ||
| Version: 0.2.0 | ||
| Summary: A Deep learning pipeline for segmentation of fluorescent labels in microscopy images | ||
@@ -10,3 +10,2 @@ Home-page: https://github.com/matjesg/deepflash2 | ||
| Keywords: unet,deep learning,semantic segmentation,microscopy,fluorescent labels | ||
| Platform: UNKNOWN | ||
| Classifier: Development Status :: 3 - Alpha | ||
@@ -16,10 +15,8 @@ Classifier: Intended Audience :: Developers | ||
| Classifier: Natural Language :: English | ||
| Classifier: Programming Language :: Python :: 3.6 | ||
| Classifier: Programming Language :: Python :: 3.7 | ||
| Classifier: Programming Language :: Python :: 3.8 | ||
| Requires-Python: >=3.6 | ||
| Requires-Python: >=3.7 | ||
| Description-Content-Type: text/markdown | ||
| License-File: LICENSE | ||
| # Welcome to | ||
@@ -180,3 +177,1 @@ | ||
| The ImagJ-Macro is available [here](https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/ImageJ/Macro_create_maps.ijm). | ||
+0
-1
@@ -1,2 +0,1 @@ | ||
| # Welcome to | ||
@@ -3,0 +2,0 @@ |
+5
-5
@@ -11,4 +11,4 @@ [DEFAULT] | ||
| branch = master | ||
| version = 0.1.8 | ||
| min_python = 3.6 | ||
| version = 0.2.0 | ||
| min_python = 3.7 | ||
| audience = Developers | ||
@@ -19,5 +19,5 @@ language = English | ||
| status = 2 | ||
| requirements = fastai>=2.1.7 zarr>=2.0 scikit-image imageio ipywidgets openpyxl albumentations>=1.0.0 natsort>=7.1.1 numba>=0.52.0 segmentation-models-pytorch>=0.2 opencv-python-headless>=4.1.1,<4.5.5 | ||
| #pip_requirements = | ||
| #conda_requirements = | ||
| requirements = fastai>=2.1.7 zarr>=2.0 scikit-image imageio ipywidgets openpyxl imagecodecs albumentations>=1.0.0 natsort>=7.1.1 numba>=0.52.0 opencv-python-headless>=4.1.1,<4.5.5 timm>=0.5.4 | ||
| pip_requirements = segmentation-models-pytorch-deepflash2 | ||
| conda_requirements = segmentation_models_pytorch>=0.3.0 | ||
| nbs_path = nbs | ||
@@ -24,0 +24,0 @@ doc_path = docs |
| [console_scripts] | ||
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
236776
5.85%27
3.85%3775
6.4%