You're Invited: Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26. RSVP
Socket
Book a DemoSign in
Socket

deepflash2

Package Overview
Dependencies
Maintainers
1
Versions
27
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

deepflash2 - pypi Package Compare versions

Comparing version
0.1.8
to
0.2.0
+127
deepflash2/config.py
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_config.ipynb (unless otherwise specified).
__all__ = ['Config']
# Cell
from dataclasses import dataclass, asdict
from pathlib import Path
import json
import torch
# Cell
@dataclass
class Config:
    "Config class for settings."
    # Project
    project_dir:str = '.'
    # GT Estimation Settings
    # staple_thres:float = 0.5
    # staple_fval:int= 1
    vote_undec:int = 0
    # Train General Settings
    n_models:int = 5
    max_splits:int=5
    random_state:int = 42
    use_gpu:bool = True
    # Pytorch Segmentation Model Settings
    arch:str = 'Unet'
    encoder_name:str = 'tu-convnext_tiny'
    encoder_weights:str = 'imagenet'
    # Train Data Settings
    num_classes:int = 2
    tile_shape:int = 512
    scale:float = 1.
    instance_labels:bool = False
    # Train Settings
    base_lr:float = 0.001
    batch_size:int = 4
    weight_decay:float = 0.001
    mixed_precision_training:bool = True
    optim:str = 'Adam'
    loss:str = 'CrossEntropyDiceLoss'
    n_epochs:int = 25
    sample_mult:int = 0
    # Train Data Augmentation
    gamma_limit_lower:int = 80
    gamma_limit_upper:int = 120
    CLAHE_clip_limit:float = 0.0
    brightness_limit:float = 0.0
    contrast_limit:float = 0.0
    flip:bool = True
    rot:int = 360
    distort_limit:float = 0
    # Loss Settings
    mode:str = 'multiclass' #currently only tested for multiclass
    loss_alpha:float = 0.5 # Tversky/Focal loss
    loss_beta:float = 0.5 # Tversky Loss
    loss_gamma:float = 2.0 # Focal loss
    loss_smooth_factor:float = 0. #SoftCrossEntropyLoss
    # Pred/Val Settings
    use_tta:bool = True
    max_tile_shift: float = 0.5
    border_padding_factor:float = 0.25
    use_gaussian: bool = True
    gaussian_kernel_sigma_scale: float = 0.125
    min_pixel_export:int = 0
    # Instance Segmentation Settings
    cellpose_model:str='nuclei'
    cellpose_diameter:int=0
    cellpose_export_class:int=1
    instance_segmentation_metrics:bool=False
    # Folder Structure
    gt_dir:str = 'GT_Estimation'
    train_dir:str = 'Training'
    pred_dir:str = 'Prediction'
    ens_dir:str = 'models'
    val_dir:str = 'valid'

    def __post_init__(self):
        # Pick a default torch device as soon as the dataclass is built.
        self.set_device()

    def set_device(self, device:str=None):
        "Set the torch device; defaults to the first GPU when CUDA is available."
        if device is None:
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

    @property
    def albumentation_kwargs(self):
        "Subset of settings consumed by the albumentations augmentation pipeline."
        kwargs = ['gamma_limit_lower', 'gamma_limit_upper', 'CLAHE_clip_limit',
                  'brightness_limit', 'contrast_limit', 'distort_limit']
        return {k: v for k, v in self.__dict__.items() if k in kwargs}

    @property
    def inference_kwargs(self):
        "Subset of settings consumed at inference time."
        inference_kwargs = ['use_tta', 'max_tile_shift', 'use_gaussian', 'scale',
                            'gaussian_kernel_sigma_scale', 'border_padding_factor']
        return {k: v for k, v in self.__dict__.items() if k in inference_kwargs}

    def save(self, path):
        'Save configuration to path (a `.json` suffix is enforced); returns the path.'
        path = Path(path).with_suffix('.json')
        with open(path, 'w') as config_file:
            json.dump(asdict(self), config_file)
        # Bug fix: `path` already carries the .json suffix, so do not append it again.
        print(f'Saved current configuration to {path}')
        return path

    def load(self, path):
        'Load configuration from a `.json` file at path; invalid files print an error.'
        path = Path(path)
        try:
            with open(path) as config_file:
                c = json.load(config_file)
            # Fall back to a default project dir if the saved one no longer exists.
            if not Path(c['project_dir']).is_dir():
                c['project_dir'] = 'deepflash2'
            for k, v in c.items():
                setattr(self, k, v)
            print(f'Successfully loaded configuration from {path}')
        # Narrowed from a bare `except:` — only I/O, JSON-parse, and missing-key
        # failures are treated as "invalid config file".
        except (OSError, ValueError, KeyError):
            print('Error! Select valid config file (.json)')
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04_inference.ipynb (unless otherwise specified).
__all__ = ['torch_gaussian', 'gaussian_kernel_2d', 'epistemic_uncertainty', 'aleatoric_uncertainty', 'uncertainty',
'get_in_slices_1d', 'get_out_slices_1d', 'TileModule', 'InferenceEnsemble']
# Cell
from typing import Tuple, List
import torch
import torch.nn.functional as F
from torchvision.transforms import Normalize
import deepflash2.tta as tta
# Cell
# adapted from https://github.com/scipy/scipy/blob/f2ec91c4908f9d67b5445fbfacce7f47518b35d1/scipy/signal/windows.py#L976
@torch.jit.script
def torch_gaussian(M:int, std:float, sym:bool=True) ->torch.Tensor:
    'Returns a Gaussian window'
    assert M > 2, 'Kernel size must be greater than 2.'
    # For a periodic (non-symmetric) window of even length, compute one extra
    # sample and drop it again at the end (mirrors scipy.signal.windows.gaussian).
    extend = (not sym) and (M % 2 == 0)
    if extend:
        M = M + 1
    offsets = torch.arange(0, M) - (M - 1.0) / 2.0
    window = torch.exp(-(offsets * offsets) / (2 * std * std))
    if extend:
        window = window[:-1]
    return window
# Cell
@torch.jit.script
def gaussian_kernel_2d(patch_size: Tuple[int, int], sigma_scale:float=1/8) ->torch.Tensor:
    'Returns a 2D Gaussian kernel tensor.'
    # Separable construction: one 1D window per axis, combined by outer product.
    sizes = [patch_size[0], patch_size[1]]
    windows = [torch_gaussian(s, std=s * sigma_scale) for s in sizes]
    kernel = torch.outer(windows[0], windows[1])
    # Normalize so the peak is exactly 1.
    kernel = kernel / kernel.max()
    # Replace any exact zeros (float underflow at the corners) with the
    # smallest value so later divisions by the kernel stay finite.
    kernel[kernel==0] = kernel.min()
    return kernel
# Cell
# adapted from https://github.com/ykwon0407/UQ_BNN/blob/master/retina/utils.py
@torch.jit.script
def epistemic_uncertainty(x: torch.Tensor):
    'Model (epistemic) uncertainty: variance of predictions across samples (dim 0).'
    mean = torch.mean(x, dim=0)
    return torch.mean(x ** 2, dim=0) - mean ** 2
@torch.jit.script
def aleatoric_uncertainty(x: torch.Tensor):
    'Data (aleatoric) uncertainty: mean of p*(1-p) across samples (dim 0).'
    return (x * (1 - x)).mean(dim=0)
@torch.jit.script
def uncertainty(x: torch.Tensor):
    'Total predictive uncertainty, rescaled to a maximum of 1.'
    # Total = epistemic + aleatoric; for probabilities in [0, 1] the sum is
    # bounded by 0.25, so dividing maps the result into [0, 1].
    total = epistemic_uncertainty(x) + aleatoric_uncertainty(x)
    return total / 0.25
# Cell
@torch.jit.export
def get_in_slices_1d(center:torch.Tensor, len_x:int, len_tile:int) ->torch.Tensor:
    'Start/stop indices into each tile itself, clipped at the image border.'
    half = len_tile / 2
    # Tiles whose center sits closer than half a tile to the left edge start late;
    # tiles overhanging the right edge stop early.
    starts = (half - center).clip(0).to(torch.int64)
    stops = torch.tensor(len_tile).clip(max=(len_x - center + half).to(torch.int64))
    return torch.stack((starts, stops))
@torch.jit.export
def get_out_slices_1d(center:torch.Tensor, len_x:int, len_tile:int) ->torch.Tensor:
    'Start/stop indices into the output image, clipped to [0, len_x].'
    half = len_tile / 2
    lo = (center - half).clip(0, len_x).to(torch.int64)
    hi = (center + half).clip(max=len_x).to(torch.int64)
    return torch.stack((lo, hi))
# Cell
class TileModule(torch.nn.Module):
    "Class for tiling data."
    # NOTE(review): instances are compiled via torch.jit.script by callers
    # (see InferenceEnsemble), so method bodies stay within the TorchScript subset.
    def __init__(self,
                 tile_shape = (512, 512),
                 scale:float = 1.,
                 border_padding_factor:float = 0.25,
                 max_tile_shift:float = 0.5):
        super(TileModule, self).__init__()
        self.tile_shape = tile_shape
        self.scale = scale
        self.border_padding_factor = border_padding_factor
        self.max_tile_shift = max_tile_shift
        # Base sampling grid spanning [-scale, scale] per axis; shifted per tile
        # in `forward` before being fed to grid_sample.
        grid_range = [torch.linspace(-self.scale, self.scale, steps=d) for d in tile_shape]
        self.deformationField = torch.meshgrid(*grid_range, indexing='ij')

    @torch.jit.export
    def get_centers_1d(self, len_x:int, len_tile:int) ->torch.Tensor:
        "Evenly spaced tile-center coordinates along one axis of length `len_x`."
        len_padding = float(len_tile*self.border_padding_factor)
        start_point = len_tile/2 - len_padding
        end_point = len_x - start_point
        #end_point = max(len_x - start_point, start_point)
        # Number of tiles grows with padding and shrinks with allowed tile shift.
        n_points = int((len_x+2*len_padding)//(len_tile*self.max_tile_shift))+1
        return torch.linspace(start_point, end_point, n_points, dtype=torch.int64)

    @torch.jit.export
    def get_center_combinations(self, shape:List[int]) ->torch.Tensor:
        "Cartesian product of the per-axis centers, returned as an (N, 2) tensor."
        c_list = [self.get_centers_1d(shape[i], self.tile_shape[i]) for i in range(2)]
        center_combinations = torch.meshgrid(c_list[0],c_list[1], indexing='ij')
        return torch.stack(center_combinations).permute([2, 1, 0]).reshape(-1,2)

    @torch.jit.export
    def get_slices_and_centers(self, shape:List[int]) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
        "Tile centers plus matching input/output slice bounds for an image of `shape`."
        # Work in the downscaled coordinate system used for tiling.
        shape = [int(shape[i]/self.scale) for i in range(2)]
        center_combinations = self.get_center_combinations(shape)
        in_slices = [get_in_slices_1d(center_combinations[:,i], shape[i], self.tile_shape[i]) for i in range(2)]
        out_slices = [get_out_slices_1d(center_combinations[:,i], shape[i], self.tile_shape[i]) for i in range(2)]
        # Centers mapped back to the original (unscaled) pixel grid.
        scaled_centers = (center_combinations*self.scale).type(torch.int64)
        return in_slices, out_slices, scaled_centers

    @torch.jit.export
    def forward(self, x, center:torch.Tensor) ->torch.Tensor:
        "Apply deformation field to image using interpolation"
        # Align grid to relative position and scale
        grids = []
        for i in range(2):
            s = x.shape[i]
            scale_ratio = self.tile_shape[i]/s
            # Tile center expressed in grid_sample's normalized [-1, 1] coordinates.
            relative_center = (center[i]-s/2)/(s/2)
            coords = (self.deformationField[i]*scale_ratio)+relative_center
            grids.append(coords.to(x))
        # grid with shape (N, H, W, 2); axes reversed because grid_sample expects (x, y)
        vgrid = torch.stack(grids[::-1], dim=-1).to(x).unsqueeze_(0)
        # input with shape (N, C, H, W) — assumes x arrives as HWC; TODO confirm
        x = x.permute(2,0,1).unsqueeze_(0)
        # Remap
        x = torch.nn.functional.grid_sample(x, vgrid, mode='nearest', padding_mode='reflection', align_corners=False)
        return x
# Cell
class InferenceEnsemble(torch.nn.Module):
    'Class for model ensemble inference'
    def __init__(self,
                 models:List[torch.nn.Module],
                 num_classes:int,
                 in_channels:int,
                 channel_means:List[float],
                 channel_stds:List[float],
                 tile_shape:Tuple[int,int]=(512,512),
                 use_gaussian: bool = True,
                 gaussian_kernel_sigma_scale:float = 1./8,
                 use_tta:bool=True,
                 border_padding_factor:float = 0.25,
                 max_tile_shift:float = 0.9,
                 scale:float = 1.,
                 device:str='cpu'):
        super().__init__()
        self.num_classes = num_classes
        self.use_tta = use_tta
        self.use_gaussian = use_gaussian
        self.gaussian_kernel_sigma_scale = gaussian_kernel_sigma_scale
        self.tile_shape = tile_shape
        self.norm = Normalize(channel_means, channel_stds)
        # Trace each model once on a dummy tile so inference runs as TorchScript.
        dummy_input = torch.rand(1, in_channels, *self.tile_shape).to(device)
        self.models = torch.nn.ModuleList([torch.jit.trace(m.to(device).eval(), dummy_input) for m in models])
        self.tiler = torch.jit.script(TileModule(tile_shape=tile_shape,
                                                 scale=scale,
                                                 border_padding_factor=border_padding_factor,
                                                 max_tile_shift=max_tile_shift))
        # Merge weights: Gaussian window (down-weights tile borders) or uniform ones.
        mw = gaussian_kernel_2d(tile_shape, gaussian_kernel_sigma_scale) if use_gaussian else torch.ones(tile_shape[0], tile_shape[1])
        self.register_buffer('mw', mw)
        tfms = [tta.HorizontalFlip(),tta.VerticalFlip()] if self.use_tta else []
        self.tta_tfms = tta.Compose(tfms)

    def forward(self, x):
        "Tile `x`, run TTA + ensemble softmax per tile, and merge weighted results."
        # Extract image shape (assuming HWC)
        sh = x.shape[:-1]
        # Workaround for sh_scaled = [int(s/self.tiler.scale) for s in img_shape]
        sh_scaled = (torch.tensor(sh)/self.tiler.scale).to(torch.int64)
        sh_scaled = [int(t.item()) for t in sh_scaled]
        # Accumulators for the weighted merge of overlapping tiles
        # (kept on x.device; original note suggested CPU for very large images).
        softmax = torch.zeros((self.num_classes, sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device)
        merge_map = torch.zeros((sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device)
        stdeviation = torch.zeros((sh_scaled[0], sh_scaled[1]), dtype=torch.float32, device=x.device)
        # Get slices for tiling
        in_slices, out_slices, center_points = self.tiler.get_slices_and_centers(sh)
        # Bug fix: Tensor.to() is not in-place; the original `self.mw.to(x)`
        # discarded its result, so the moved buffer must be re-assigned.
        self.mw = self.mw.to(x)
        # Loop over tiles
        for i, cp in enumerate(center_points):
            tile = self.tiler(x, cp)
            # Normalize
            tile = self.norm(tile)
            smxs = []
            # Loop over tt-augmentations
            for t in self.tta_tfms.items:
                aug_tile = t.augment(tile)
                # Loop over models
                for model in self.models:
                    logits = model(aug_tile)
                    logits = t.deaugment(logits)
                    smxs.append(F.softmax(logits, dim=1))
            smxs = torch.stack(smxs)
            ix0, ix1, iy0, iy1 = in_slices[0][0][i], in_slices[0][1][i], in_slices[1][0][i], in_slices[1][1][i]
            ox0, ox1, oy0, oy1 = out_slices[0][0][i], out_slices[0][1][i], out_slices[1][0][i], out_slices[1][1][i]
            # Apply weighting
            batch_smx = torch.mean(smxs, dim=0)*self.mw.view(1,1,self.mw.shape[0],self.mw.shape[1])
            softmax[..., ox0:ox1, oy0:oy1] += batch_smx[0, ..., ix0:ix1, iy0:iy1].to(softmax)
            merge_map[ox0:ox1, oy0:oy1] += self.mw[ix0:ix1, iy0:iy1].to(merge_map)
            # Uncertainty estimates
            batch_std = torch.mean(uncertainty(smxs), dim=1)*self.mw.view(1,self.mw.shape[0],self.mw.shape[1])
            stdeviation[ox0:ox1, oy0:oy1] += batch_std[0][ix0:ix1, iy0:iy1].to(stdeviation)
        # Normalize weighting
        softmax /= torch.unsqueeze(merge_map, 0)
        stdeviation /= merge_map
        # Rescale results
        if self.tiler.scale!=1.:
            # Needs checking if these are the best options
            softmax = F.interpolate(softmax.unsqueeze_(0), size=sh, mode="bilinear", align_corners=False)[0]
            stdeviation = stdeviation.view(1, 1, stdeviation.shape[0], stdeviation.shape[1])
            stdeviation = F.interpolate(stdeviation, size=sh, mode="bilinear", align_corners=False)[0][0]
        argmax = torch.argmax(softmax, dim=0).to(torch.uint8)
        return argmax, softmax, stdeviation
+2
-7
Metadata-Version: 2.1
Name: deepflash2
Version: 0.1.8
Version: 0.2.0
Summary: A Deep learning pipeline for segmentation of fluorescent labels in microscopy images

@@ -10,3 +10,2 @@ Home-page: https://github.com/matjesg/deepflash2

Keywords: unet,deep learning,semantic segmentation,microscopy,fluorescent labels
Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha

@@ -16,10 +15,8 @@ Classifier: Intended Audience :: Developers

Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Requires-Python: >=3.6
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
# Welcome to

@@ -180,3 +177,1 @@

The ImageJ-Macro is available [here](https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/ImageJ/Macro_create_maps.ijm).

@@ -9,6 +9,8 @@ pip

openpyxl
imagecodecs
albumentations>=1.0.0
natsort>=7.1.1
numba>=0.52.0
segmentation-models-pytorch>=0.2
opencv-python-headless<4.5.5,>=4.1.1
timm>=0.5.4
segmentation-models-pytorch-deepflash2

@@ -10,5 +10,7 @@ CONTRIBUTING.md

deepflash2/all.py
deepflash2/config.py
deepflash2/data.py
deepflash2/gt.py
deepflash2/gui.py
deepflash2/inference.py
deepflash2/learner.py

@@ -22,5 +24,4 @@ deepflash2/losses.py

deepflash2.egg-info/dependency_links.txt
deepflash2.egg-info/entry_points.txt
deepflash2.egg-info/not-zip-safe
deepflash2.egg-info/requires.txt
deepflash2.egg-info/top_level.txt

@@ -1,1 +0,1 @@

__version__ = "0.1.8"
__version__ = "0.2.0"

@@ -5,9 +5,8 @@ # AUTOGENERATED BY NBDEV! DO NOT EDIT!

index = {"Config": "00_learner.ipynb",
"energy_score": "00_learner.ipynb",
"EnsemblePredict": "00_learner.ipynb",
"EnsembleLearner": "00_learner.ipynb",
index = {"Config": "00_config.ipynb",
"ARCHITECTURES": "01_models.ipynb",
"ENCODERS": "01_models.ipynb",
"get_pretrained_options": "01_models.ipynb",
"smp.decoders.unet.decoder.UnetDecoder.forward": "01_models.ipynb",
"PATCH_UNET_DECODER": "01_models.ipynb",
"create_smp_model": "01_models.ipynb",

@@ -22,5 +21,18 @@ "save_smp_model": "01_models.ipynb",

"DeformationField": "02_data.ipynb",
"tiles_in_rectangles": "02_data.ipynb",
"BaseDataset": "02_data.ipynb",
"RandomTileDataset": "02_data.ipynb",
"TileDataset": "02_data.ipynb",
"EnsembleBase": "03_learner.ipynb",
"EnsembleLearner": "03_learner.ipynb",
"EnsemblePredictor": "03_learner.ipynb",
"torch_gaussian": "04_inference.ipynb",
"gaussian_kernel_2d": "04_inference.ipynb",
"epistemic_uncertainty": "04_inference.ipynb",
"aleatoric_uncertainty": "04_inference.ipynb",
"uncertainty": "04_inference.ipynb",
"get_in_slices_1d": "04_inference.ipynb",
"get_out_slices_1d": "04_inference.ipynb",
"TileModule": "04_inference.ipynb",
"InferenceEnsemble": "04_inference.ipynb",
"LOSSES": "05_losses.ipynb",

@@ -30,2 +42,3 @@ "FastaiLoss": "05_losses.ipynb",

"JointLoss": "05_losses.ipynb",
"Poly1CrossEntropyLoss": "05_losses.ipynb",
"get_loss": "05_losses.ipynb",

@@ -37,6 +50,7 @@ "unzip": "06_utils.ipynb",

"compose_albumentations": "06_utils.ipynb",
"ensemble_results": "06_utils.ipynb",
"clean_show": "06_utils.ipynb",
"plot_results": "06_utils.ipynb",
"Recorder.plot_metrics": "06_utils.ipynb",
"iou": "06_utils.ipynb",
"multiclass_dice_score": "06_utils.ipynb",
"binary_dice_score": "06_utils.ipynb",
"dice_score": "06_utils.ipynb",

@@ -54,9 +68,8 @@ "label_mask": "06_utils.ipynb",

"BaseTransform": "07_tta.ipynb",
"HorizontalFlip": "07_tta.ipynb",
"VerticalFlip": "07_tta.ipynb",
"Rotate90": "07_tta.ipynb",
"Chain": "07_tta.ipynb",
"Transformer": "07_tta.ipynb",
"Compose": "07_tta.ipynb",
"Merger": "07_tta.ipynb",
"HorizontalFlip": "07_tta.ipynb",
"VerticalFlip": "07_tta.ipynb",
"Rotate90": "07_tta.ipynb",
"GRID_COLS": "08_gui.ipynb",

@@ -100,10 +113,11 @@ "COLS_PRED_KEEP": "08_gui.ipynb",

"import_sitk": "09_gt.ipynb",
"staple": "09_gt.ipynb",
"staple_multi_label": "09_gt.ipynb",
"m_voting": "09_gt.ipynb",
"msk_show": "09_gt.ipynb",
"GTEstimator": "09_gt.ipynb"}
modules = ["learner.py",
modules = ["config.py",
"models.py",
"data.py",
"learner.py",
"inference.py",
"losses.py",

@@ -110,0 +124,0 @@ "utils.py",

@@ -0,1 +1,2 @@

from .config import *
from .learner import *

@@ -2,0 +3,0 @@ from .data import *

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_data.ipynb (unless otherwise specified).
__all__ = ['show', 'preprocess_mask', 'DeformationField', 'BaseDataset', 'RandomTileDataset', 'TileDataset']
__all__ = ['show', 'preprocess_mask', 'DeformationField', 'tiles_in_rectangles', 'BaseDataset', 'RandomTileDataset',
'TileDataset']

@@ -23,6 +24,9 @@ # Cell

from fastcore.all import *
from fastprogress import progress_bar
from .utils import clean_show
# Cell
def show(*obj, file_name=None, overlay=False, pred=False,
show_bbox=True, figsize=(10,10), cmap='binary_r', **kwargs):
def show(*obj, file_name=None, overlay=False, pred=False, num_classes=2,
show_bbox=False, figsize=(10,10), cmap='viridis', **kwargs):
"Show image, mask, and weight (optional)"

@@ -63,5 +67,2 @@ if len(obj)==3:

if cmap is None:
cmap = 'binary_r' if msk.max()==1 else cmap
# Weights preprocessing

@@ -78,8 +79,4 @@ if weight is not None:

# Plot img
img_ax.imshow(img, cmap=cmap)
if file_name is not None:
img_ax.set_title('Image {}'.format(file_name))
else:
img_ax.set_title('Image')
img_ax.set_axis_off()
img_title = f'Image {file_name}' if file_name is not None else 'Image'
clean_show(img_ax, img, img_title, cmap)

@@ -91,7 +88,7 @@ # Plot img and mask

img_l2o = label2rgb(label_image, image=img, bg_label=0, alpha=.8, image_alpha=1)
ax[1].set_title('Image + Mask (#ROIs: {})'.format(label_image.max()))
ax[1].imshow(img_l2o)
pred_title = 'Image + Mask (#ROIs: {})'.format(label_image.max())
clean_show(ax[1], img_l2o, pred_title, None)
else:
ax[1].imshow(msk, cmap=cmap)
ax[1].set_title('Mask')
vkwargs = {'vmin':0, 'vmax':num_classes-1}
clean_show(ax[1], msk, 'Mask', cmap, cbar='classes', ticks=num_classes, **vkwargs)
if show_bbox: ax[1].add_patch(copy(bbox))

@@ -135,3 +132,3 @@

# adapted from Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70.
def preprocess_mask(clabels=None, instlabels=None, remove_connectivity=True, n_classes = 2):
def preprocess_mask(clabels=None, instlabels=None, remove_connectivity=True, num_classes = 2):
"Calculates the weights from the given mask (classlabels `clabels` or `instlabels`)."

@@ -170,3 +167,3 @@

# of that class, avoid overlapping instances
dil = cv2.morphologyEx(il, cv2.MORPH_CLOSE, kernel=np.ones((3,) * n_classes))
dil = cv2.morphologyEx(il, cv2.MORPH_CLOSE, kernel=np.ones((3,) * num_classes))
overlap_cand = np.unique(np.where(dil!=il, dil, 0))

@@ -176,3 +173,3 @@ labels[np.isin(il, overlap_cand, invert=True)] = c

for instance in overlap_cand[1:]:
objectMaskDil = cv2.dilate((labels == c).astype('uint8'), kernel=np.ones((3,) * n_classes),iterations = 1)
objectMaskDil = cv2.dilate((labels == c).astype('uint8'), kernel=np.ones((3,) * num_classes),iterations = 1)
labels[(instlabels == instance) & (objectMaskDil == 0)] = c

@@ -263,3 +260,3 @@ else:

def _read_img(path, **kwargs):
"Read image and normalize to 0-1 range"
"Read image"
if path.suffix == '.zarr':

@@ -269,4 +266,4 @@ img = zarr.convenience.open(path.as_posix())

img = imageio.imread(path, **kwargs)
if img.max()>1.:
img = img/np.iinfo(img.dtype).max
#if img.max()>1.:
# img = img/np.iinfo(img.dtype).max
if img.ndim == 2:

@@ -277,3 +274,3 @@ img = np.expand_dims(img, axis=2)

# Cell
def _read_msk(path, n_classes=2, instance_labels=False, remove_connectivity=True, **kwargs):
def _read_msk(path, num_classes=2, instance_labels=False, remove_connectivity=True, **kwargs):
"Read image and check classes"

@@ -285,6 +282,6 @@ if path.suffix == '.zarr':

if instance_labels:
msk = preprocess_mask(clabels=None, instlabels=msk, remove_connectivity=remove_connectivity, n_classes=n_classes)
msk = preprocess_mask(clabels=None, instlabels=msk, remove_connectivity=remove_connectivity, num_classes=num_classes)
else:
# handle binary labels that are scaled different from 0 and 1
if n_classes==2 and np.max(msk)>1 and len(np.unique(msk))==2:
if num_classes==2 and np.max(msk)>1 and len(np.unique(msk))==2:
msk = msk//np.max(msk)

@@ -296,3 +293,3 @@ # Remove channels if no extra information given

# Mask check
assert len(np.unique(msk))<=n_classes, f'Expected mask with {n_classes} classes but got mask with {len(np.unique(msk))} classes {np.unique(msk)} . Are you using instance labels?'
assert len(np.unique(msk))<=num_classes, f'Expected mask with {num_classes} classes but got mask with {len(np.unique(msk))} classes {np.unique(msk)} . Are you using instance labels?'
assert len(msk.shape)==2, 'Currently, only masks with a single channel are supported.'

@@ -302,10 +299,23 @@ return msk.astype('uint8')

# Cell
def tiles_in_rectangles(H, W, h, w):
    '''Get smaller rectangles needed to fill the larger rectangle'''
    rows = math.ceil(float(H) / float(h))
    cols = math.ceil(float(W) / float(w))
    return rows * cols
# Cell
class BaseDataset(Dataset):
def __init__(self, files, label_fn=None, instance_labels = False, n_classes=2, ignore={},remove_connectivity=True,stats=None,normalize=True,
def __init__(self, files, label_fn=None, instance_labels = False, num_classes=2, ignore={},remove_connectivity=True,
stats=None,normalize=True, use_zarr_data=True,
tile_shape=(512,512), padding=(0,0),preproc_dir=None, verbose=1, scale=1, pdf_reshape=512, use_preprocessed_labels=False, **kwargs):
store_attr('files, label_fn, instance_labels, n_classes, ignore, tile_shape, remove_connectivity, padding, preproc_dir, normalize, scale, pdf_reshape, use_preprocessed_labels')
self.c = n_classes
store_attr('files, label_fn, instance_labels, num_classes, ignore, tile_shape, remove_connectivity, padding, preproc_dir, stats, normalize, scale, pdf_reshape, use_preprocessed_labels')
self.c = num_classes
self.use_zarr_data=False
if self.normalize:
self.stats = stats or self.compute_stats()
if self.stats is None:
self.mean_sum, self.var_sum = 0., 0.
self.max_tile_count = 0
self.actual_tile_shape = (np.array(self.tile_shape)-np.array(self.padding))

@@ -315,7 +325,9 @@ if label_fn is not None:

root = zarr.group(store=self.preproc_dir, overwrite= not use_preprocessed_labels)
self.labels, self.pdfs = root.require_groups('labels','pdfs')
self._preproc(verbose)
self.data, self.labels, self.pdfs = root.require_groups('data', 'labels','pdfs')
self._preproc(use_zarr_data=use_zarr_data, verbose=verbose)
def read_img(self, *args, **kwargs):
return _read_img(*args, **kwargs)
def read_img(self, path, **kwargs):
if self.use_zarr_data: img = self.data[path.name]
else: img = _read_img(path, **kwargs)
return img

@@ -325,10 +337,18 @@ def read_mask(self, *args, **kwargs):

def _create_cdf(self, mask, ignore, fbr=None):
def _create_cdf(self, mask, ignore, sampling_weights=None, igonore_edges_pct=0):
'Creates a cumulated probability density function (CDF) for weighted sampling '
# Create probability density function
# Create mask
mask = mask[:]
fbr = fbr or np.sum(mask>0)/np.sum(mask==0)
pdf = (mask>0) + (mask==0) * fbr
if sampling_weights is None:
classes, counts = np.unique(mask, return_counts=True)
sampling_weights = {k:1-v/mask.size for k,v in zip(classes, counts)}
# Set pixel weights
pdf = np.zeros_like(mask, dtype=np.float32)
for k, v in sampling_weights.items():
pdf[mask==k] = v
# Set weight and sampling probability for ignored regions to 0

@@ -338,37 +358,64 @@ if ignore is not None:

#if igonore_edges:
# w = int(self.tile_shape[0]*0.25)
# pdf[:, :w] = pdf[:, -w:] = 0
# pdf[:w, :] = pdf[-w:, :] = 0
if igonore_edges_pct>0:
w = int(self.tile_shape[0]*igonore_edges_pct/2) #0.25
pdf[:, :w] = pdf[:, -w:] = 0
pdf[:w, :] = pdf[-w:, :] = 0
# Reshape
reshape_w = int((pdf.shape[1]/pdf.shape[0])*self.pdf_reshape)
pdf = cv2.resize(pdf, dsize=(reshape_w, self.pdf_reshape))
# Normalize pixel weights
pdf /= pdf.sum()
return np.cumsum(pdf/np.sum(pdf))
def _preproc_file(self, file):
"Preprocesses and saves labels (msk), weights, and pdf."
label_path = self.label_fn(file)
ign = self.ignore[file.name] if file.name in self.ignore else None
lbl = self.read_mask(label_path, n_classes=self.c, instance_labels=self.instance_labels, remove_connectivity=self.remove_connectivity)
self.labels[file.name] = lbl
self.pdfs[file.name] = self._create_cdf(lbl, ignore=ign)
def _preproc_file(self, file, use_zarr_data=True):
"Preprocesses and saves images, labels (msk), weights, and pdf."
def _preproc(self, verbose=0):
# Load and save image
img = self.read_img(file)
if self.stats is None:
self.mean_sum += img.mean((0,1))
self.var_sum += img.var((0,1))
self.max_tile_count = max(self.max_tile_count, tiles_in_rectangles(*img.shape[:2], *self.actual_tile_shape))
if use_zarr_data: self.data[file.name] = img
if self.label_fn is not None:
# Load and save image
label_path = self.label_fn(file)
ign = self.ignore[file.name] if file.name in self.ignore else None
lbl = self.read_mask(label_path, num_classes=self.c, instance_labels=self.instance_labels, remove_connectivity=self.remove_connectivity)
self.labels[file.name] = lbl
self.pdfs[file.name] = self._create_cdf(lbl, ignore=ign)
def _preproc(self, use_zarr_data=True, verbose=0):
using_cache = False
for f in self.files:
if verbose>0: print('Preprocessing data')
for f in progress_bar(self.files, leave=True if verbose>0 else False):
if self.use_preprocessed_labels:
try:
self.labels[f.name]
self.pdfs[f.name]
if not using_cache:
if verbose>0: print(f'Using preprocessed masks from {self.preproc_dir}')
using_cache = True
self.data[f.name]
if self.label_fn is not None:
self.labels[f.name]
self.pdfs[f.name]
if not using_cache:
#if verbose>0: print(f'Using preprocessed data from {self.preproc_dir}')
using_cache = True
except:
if verbose>0: print('Preprocessing', f.name)
self._preproc_file(f)
self._preproc_file(f, use_zarr_data=use_zarr_data)
else:
if verbose>0: print('Preprocessing', f.name)
self._preproc_file(f)
self._preproc_file(f, use_zarr_data=use_zarr_data)
self.use_zarr_data=use_zarr_data
if self.stats is None:
n = len(self.files)
#https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch/60803379#60803379
self.stats = {'channel_means': self.mean_sum/n,
'channel_stds': np.sqrt(self.var_sum/n),
'max_tiles_per_image': self.max_tile_count}
print('Calculated stats', self.stats)
def get_data(self, files=None, max_n=None, mask=False):

@@ -401,22 +448,6 @@ if files is not None:

lbl = self.labels[f.name]
show(img, lbl, file_name=f.name, figsize=figsize, show_bbox=False, **kwargs)
show(img, lbl, file_name=f.name, figsize=figsize, show_bbox=False, num_classes=self.num_classes, **kwargs)
else:
show(img, file_name=f.name, figsize=figsize, show_bbox=False, **kwargs)
#https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch/60803379#60803379
def compute_stats(self, max_samples=50):
"Computes mean and std from files"
print('Computing Stats...')
mean_sum, var_sum = 0., 0.
for i, f in enumerate(self.files, 1):
img = self.read_img(f)[:]
mean_sum += img.mean((0,1))
var_sum += img.var((0,1))
if i==max_samples:
print(f'Calculated stats from {i} files')
continue
self.mean = mean_sum/i
self.std = np.sqrt(var_sum/i)
return self.mean, self.std#*2
# Cell

@@ -435,7 +466,6 @@ class RandomTileDataset(BaseDataset):

if self.sample_mult is None:
#tile_shape = np.array(self.tile_shape)-np.array(self.padding)
#msk_shape = np.array(self.get_data(max_n=1)[0].shape[:-1])
#msk_shape = np.array(lbl.shape[-2:])
#sample_mult = int(np.product(np.floor(msk_shape/tile_shape)))
self.sample_mult = max(1, min_length//len(self.files))
self.sample_mult = max(int(self.stats['max_tiles_per_image']/self.scale**2),
min_length//len(self.files))

@@ -445,3 +475,5 @@ tfms = self.albumentations_tfms

tfms += [
A.Normalize(mean=self.stats[0], std=self.stats[1], max_pixel_value=1.)
A.Normalize(mean=self.stats['channel_means'],
std=self.stats['channel_stds'],
max_pixel_value=1.0)
]

@@ -491,5 +523,5 @@ self.tfms = A.Compose(tfms+[ToTensorV2()])

n_inp = 1
def __init__(self, *args, val_length=None, val_seed=42, is_zarr=False, shift=1., border_padding_factor=0.25, return_index=False, **kwargs):
def __init__(self, *args, val_length=None, val_seed=42, max_tile_shift=1., border_padding_factor=0.25, return_index=False, **kwargs):
super().__init__(*args, **kwargs)
self.shift = shift
self.max_tile_shift = max_tile_shift
self.bpf = border_padding_factor

@@ -509,13 +541,8 @@ self.return_index = return_index

tfms += [
#A.ToFloat(),
A.Normalize(mean=self.stats[0], std=self.stats[1], max_pixel_value=1.)
A.Normalize(mean=self.stats['channel_means'],
std=self.stats['channel_stds'],
max_pixel_value=1.0)
]
self.tfms = A.Compose(tfms+[ToTensorV2()])
if self.files[0].suffix == '.zarr' or is_zarr:
self.data = zarr.open_group(self.files[0].parent.as_posix(), mode='r')
is_zarr = True
else:
root = zarr.group(store=zarr.storage.TempStore(), overwrite=True)
self.data = root.create_group('data')

@@ -525,3 +552,2 @@ j = 0

img = self.read_img(file)
if not is_zarr: self.data[file.name] = img
# Tiling

@@ -531,3 +557,3 @@ data_shape = tuple(int(x//self.scale) for x in img.shape[:-1])

end_points = [(s - st) for s, st in zip(data_shape, start_points)]
n_points = [int((s+2*o*self.bpf)//(o*self.shift))+1 for s, o in zip(data_shape, self.output_shape)]
n_points = [int((s+2*o*self.bpf)//(o*self.max_tile_shift))+1 for s, o in zip(data_shape, self.output_shape)]
center_points = [np.linspace(st, e, num=n, endpoint=True, dtype=np.int64) for st, e, n in zip(start_points, end_points, n_points)]

@@ -571,3 +597,4 @@ for cx in center_points[1]:

img_path = self.files[self.image_indices[idx]]
img = self.data[img_path.name]
#img = self.data[img_path.name]
img = self.read_img(img_path)
centerPos = self.centers[idx]

@@ -574,0 +601,0 @@

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09_gt.ipynb (unless otherwise specified).
__all__ = ['import_sitk', 'staple', 'm_voting', 'msk_show', 'GTEstimator']
__all__ = ['import_sitk', 'staple_multi_label', 'm_voting', 'GTEstimator']

@@ -11,7 +11,8 @@ # Cell

from fastai.data.transforms import get_image_files
import matplotlib.pyplot as plt
from .data import _read_msk
from .learner import Config
from .utils import save_mask, dice_score, install_package, get_instance_segmentation_metrics
from .config import Config
from .utils import clean_show, save_mask, dice_score, install_package, get_instance_segmentation_metrics

@@ -30,10 +31,10 @@ # Cell

# Cell
def staple(segmentations, foregroundValue = 1, threshold = 0.5):
def staple_multi_label(segmentations, label_undecided_pixel=1):
'STAPLE: Simultaneous Truth and Performance Level Estimation with simple ITK'
sitk = import_sitk()
segmentations = [sitk.GetImageFromArray(x) for x in segmentations]
STAPLE_probabilities = sitk.STAPLE(segmentations)
STAPLE = STAPLE_probabilities > threshold
#STAPLE = sitk.GetArrayViewFromImage(STAPLE)
return sitk.GetArrayFromImage(STAPLE)
sitk_segmentations = [sitk.GetImageFromArray(x) for x in segmentations]
STAPLE = sitk.MultiLabelSTAPLEImageFilter()
STAPLE.SetLabelForUndecidedPixels(label_undecided_pixel)
msk = STAPLE.Execute(sitk_segmentations)
return sitk.GetArrayFromImage(msk)

@@ -49,18 +50,2 @@ # Cell

# Cell
from mpl_toolkits.axes_grid1 import make_axes_locatable
def msk_show(ax, msk, title, cbar=None, ticks=None, **kwargs):
img = ax.imshow(msk, **kwargs)
if cbar is not None:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
if cbar=='plot':
scale = ticks/(ticks+1)
cbr = plt.colorbar(img, cax=cax, ticks=[i*(scale)+(scale/2) for i in range(0, ticks+1)])
cbr.set_ticklabels([i for i in range(0, ticks+1)])
cbr.set_label('# of experts', rotation=270, labelpad=+15, fontsize="larger")
else: cax.set_axis_off()
ax.set_axis_off()
ax.set_title(title)
# Cell
class GTEstimator(GetAttr):

@@ -104,8 +89,9 @@ "Class for ground truth estimation"

fig, axs = plt.subplots(nrows=1, ncols=len(exps), figsize=figsize, **kwargs)
vkwargs = {'vmin':0, 'vmax':self.num_classes-1}
for i, exp in enumerate(exps):
try:
msk = _read_msk(self.mask_fn(exp,m), instance_labels=self.instance_labels)
except:
raise ValueError('Ground truth estimation currently only suppports two classes (binary masks or instance labels)')
msk_show(axs[i], msk, exp, cmap=self.cmap)
msk = _read_msk(self.mask_fn(exp,m), num_classes=self.num_classes, instance_labels=self.instance_labels)
if i == len(exps) - 1:
clean_show(axs[i], msk, exp, self.cmap, cbar='classes', ticks=self.num_classes, **vkwargs)
else:
clean_show(axs[i], msk, exp, self.cmap, **vkwargs)
fig.text(0, .5, m, ha='center', va='center', rotation=90)

@@ -121,10 +107,11 @@ plt.tight_layout()

for m, exps in progress_bar(self.masks.items()):
masks = [_read_msk(self.mask_fn(exp,m), instance_labels=self.instance_labels) for exp in exps]
masks = [_read_msk(self.mask_fn(exp,m), num_classes=self.num_classes, instance_labels=self.instance_labels) for exp in exps]
if method=='STAPLE':
ref = staple(masks, self.staple_fval, self.staple_thres)
#ref = staple(masks, self.staple_fval, self.staple_thres)
ref = staple_multi_label(masks, self.vote_undec)
elif method=='majority_voting':
ref = m_voting(masks, self.majority_vote_undec)
ref = m_voting(masks, self.vote_undec)
refs[m] = ref
#assert ref.mean() > 0, 'Please try again!'
df_tmp = pd.DataFrame({'method': method, 'file' : m, 'exp' : exps, 'dice_score': [dice_score(ref, msk) for msk in masks]})
df_tmp = pd.DataFrame({'method': method, 'file' : m, 'exp' : exps, 'dice_score': [dice_score(ref, msk, num_classes=self.num_classes) for msk in masks]})
if self.instance_segmentation_metrics:

@@ -164,9 +151,13 @@ mAP, AP = [],[]

for f in files:
fig, ax = plt.subplots(ncols=2, figsize=figsize, **kwargs)
# GT
msk_show(ax[0], self.gt[method][f], f'{method} (binary mask)', cbar='', cmap=self.cmap)
# Experts
masks = [_read_msk(self.mask_fn(exp,f), instance_labels=self.instance_labels) for exp in self.masks[f]]
masks_av = np.array(masks).sum(axis=0)#/len(masks)
msk_show(ax[1], masks_av, 'Expert Overlay', cbar='plot', ticks=len(masks), cmap=plt.cm.get_cmap(self.cmap, len(masks)+1))
if self.num_classes==2:
fig, ax = plt.subplots(ncols=2, figsize=figsize, **kwargs)
# GT
clean_show(ax[0], self.gt[method][f], f'{method} (binary mask)', cbar='', cmap=self.cmap)
# Experts
masks = [_read_msk(self.mask_fn(exp,f), num_classes=self.num_classes, instance_labels=self.instance_labels) for exp in self.masks[f]]
masks_av = np.array(masks).sum(axis=0)#/len(masks)
clean_show(ax[1], masks_av, 'Expert Overlay', cbar='experts', ticks=len(masks), cmap=plt.cm.get_cmap(self.cmap, len(masks)+1))
else:
fig, ax = plt.subplots(ncols=1, figsize=figsize, **kwargs)
clean_show(ax, self.gt[method][f], f'{method}', cbar='classes', cmap=self.cmap, ticks=self.num_classes)
# Results

@@ -176,7 +167,2 @@ metrics = ['dice_score', 'mean_average_precision', 'average_precision_at_iou_50'] if self.instance_segmentation_metrics else ['dice_score']

plt_df = self.df_res[self.df_res.file==f].set_index('exp')[metrics].append(av_df)
#plt_df.columns = [f'Similarity (Dice Score)']
#tbl = pd.plotting.table(ax[2], np.round(plt_df,3), loc='center', colWidths=[.5])
#tbl.set_fontsize(14)
#tbl.scale(1, 2)
#ax[2].set_axis_off()
fig.text(0, .5, f, ha='center', va='center', rotation=90)

@@ -183,0 +169,0 @@ plt.tight_layout()

@@ -1,27 +0,23 @@

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_learner.ipynb (unless otherwise specified).
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_learner.ipynb (unless otherwise specified).
__all__ = ['Config', 'energy_score', 'EnsemblePredict', 'EnsembleLearner']
__all__ = ['EnsembleBase', 'EnsembleLearner', 'EnsemblePredictor']
# Cell
import shutil, gc, joblib, json, zarr, numpy as np, pandas as pd
import torch
import time
import tifffile, cv2
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from dataclasses import dataclass, field, asdict
import zarr
import pandas as pd
import numpy as np
import cv2
import tifffile
from pathlib import Path
from typing import List, Union, Tuple
from sklearn import svm
from skimage.color import label2rgb
from sklearn.model_selection import KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from scipy.ndimage.filters import gaussian_filter
from skimage.color import label2rgb
import matplotlib.pyplot as plt
from fastprogress import progress_bar
from fastcore.basics import patch, GetAttr
from fastcore.foundation import add_docs, L
from fastcore.basics import GetAttr
from fastcore.foundation import L
from fastai import optimizer
from fastai.torch_core import TensorImage
from fastai.learner import Learner

@@ -32,120 +28,19 @@ from fastai.callback.tracker import SaveModelCallback

from fastai.data.transforms import get_image_files, get_files
from fastai.vision.augment import Brightness, Contrast, Saturation
from fastai.losses import CrossEntropyLossFlat
from fastai.metrics import Dice, DiceMulti
from .config import Config
from .data import BaseDataset, TileDataset, RandomTileDataset
from .models import create_smp_model, save_smp_model, load_smp_model, run_cellpose
from .inference import InferenceEnsemble
from .losses import get_loss
from .models import create_smp_model, save_smp_model, load_smp_model, run_cellpose
from .data import TileDataset, RandomTileDataset, _read_img, _read_msk
from .utils import dice_score, plot_results, get_label_fn, calc_iterations, save_mask, save_unc, export_roi_set, get_instance_segmentation_metrics
from .utils import compose_albumentations as _compose_albumentations
import deepflash2.tta as tta
from .utils import dice_score, binary_dice_score, plot_results, get_label_fn, save_mask, save_unc, export_roi_set, get_instance_segmentation_metrics
from fastai.metrics import Dice, DiceMulti
# Cell
@dataclass
class Config:
"Config class for settings."
import matplotlib.pyplot as plt
import warnings
# Project
project_dir:str = '.'
#https://discuss.pytorch.org/t/slow-forward-on-traced-graph-on-cuda-2nd-iteration/118445/7
try: torch._C._jit_set_fusion_strategy([('STATIC', 0)])
except: torch._C._jit_set_bailout_depth(0)
# GT Estimation Settings
staple_thres:float = 0.5
staple_fval:int= 1
majority_vote_undec:int = 1
# Train General Settings
n_models:int = 5
max_splits:int=5
random_state:int = 42
# Pytorch Segmentation Model Settings
arch:str = 'Unet'
encoder_name:str = 'resnet34'
encoder_weights:str = 'imagenet'
# Train Data Settings
n_classes:int = 2
tile_shape:int = 512
instance_labels:bool = False
# Train Settings
base_lr:float = 0.001
batch_size:int = 4
weight_decay:float = 0.001
mixed_precision_training:bool = False
optim:str = 'Adam'
loss:str = 'CrossEntropyDiceLoss'
n_iter:int = 2500
sample_mult:int = 0
# Validation and Prediction Settings
tta:bool = True
border_padding_factor:float = 0.25
shift:float = 0.5
# Train Data Augmentation
gamma_limit_lower:int = 80
gamma_limit_upper:int = 120
CLAHE_clip_limit:float = 0.0
brightness_limit:float = 0.0
contrast_limit:float = 0.0
flip:bool = True
rot:int = 360
distort_limit:float = 0
# Loss Settings
mode:str = 'multiclass' #currently only tested for multiclass
loss_alpha:float = 0.5 # Twerksky/Focal loss
loss_beta:float = 0.5 # Twerksy Loss
loss_gamma:float = 2.0 # Focal loss
loss_smooth_factor:float = 0. #SoftCrossEntropyLoss
# Pred Settings
pred_tta:bool = True
min_pixel_export:int = 0
# Instance Segmentation Settings
cellpose_model:str='nuclei'
cellpose_diameter:int=0
cellpose_export_class:int=1
instance_segmentation_metrics:bool=False
# Folder Structure
gt_dir:str = 'GT_Estimation'
train_dir:str = 'Training'
pred_dir:str = 'Prediction'
ens_dir:str = 'models'
val_dir:str = 'valid'
@property
def albumentation_kwargs(self):
kwargs = ['gamma_limit_lower', 'gamma_limit_upper', 'CLAHE_clip_limit',
'brightness_limit', 'contrast_limit', 'distort_limit']
return dict(filter(lambda x: x[0] in kwargs, self.__dict__.items()))
@property
def svm_kwargs(self):
svm_vars = ['kernel', 'nu', 'gamma']
return dict(filter(lambda x: x[0] in svm_vars, self.__dict__.items()))
def save(self, path):
'Save configuration to path'
path = Path(path).with_suffix('.json')
with open(path, 'w') as config_file:
json.dump(asdict(self), config_file)
print(f'Saved current configuration to {path}.json')
return path
def load(self, path):
'Load configuration from path'
path = Path(path)
try:
with open(path) as config_file: c = json.load(config_file)
if not Path(c['project_dir']).is_dir(): c['project_dir']='deepflash2'
for k,v in c.items(): setattr(self, k, v)
print(f'Successsfully loaded configuration from {path}')
except:
print('Error! Select valid config file (.json)')
# Cell

@@ -164,170 +59,71 @@ _optim_dict = {

# Cell
# from https://github.com/MIC-DKFZ/nnUNet/blob/2fade8f32607220f8598544f0d5b5e5fa73768e5/nnunet/network_architecture/neural_network.py#L250
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
tmp = np.zeros(patch_size)
center_coords = [i // 2 for i in patch_size]
sigmas = [i * sigma_scale for i in patch_size]
tmp[tuple(center_coords)] = 1
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
gaussian_importance_map = gaussian_importance_map.astype(np.float32)
class EnsembleBase(GetAttr):
_default = 'config'
def __init__(self, image_dir:str=None, mask_dir:str=None, files:List[Path]=None, label_fn:callable=None,
config:Config=None, path:Path=None, zarr_store:str=None):
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
gaussian_importance_map[gaussian_importance_map == 0] = np.min(
gaussian_importance_map[gaussian_importance_map != 0])
self.config = config or Config()
self.path = Path(path) if path is not None else Path('.')
self.label_fn = None
self.files = L()
return gaussian_importance_map
store = str(zarr_store) if zarr_store else zarr.storage.TempStore()
root = zarr.group(store=store, overwrite=False)
self.store = root.chunk_store.path
self.g_pred, self.g_smx, self.g_std = root.require_groups('preds', 'smxs', 'stds')
# Cell
def energy_score(x, T=1, dim=1):
    'Return the energy score as proposed by Liu, Weitang, et al. (2020).'
    # E(x) = -T * logsumexp(x / T) along `dim`.
    scaled_lse = torch.logsumexp(x / T, dim=dim)
    return -T * scaled_lse
if any(v is not None for v in (image_dir, files)):
self.files = L(files) or self.get_images(image_dir)
# Cell
class EnsemblePredict():
'Class for prediction with multiple models'
def __init__(self, models_paths, zarr_store=None):
self.models_paths = models_paths
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.init_models()
if any(v is not None for v in (mask_dir, label_fn)):
assert hasattr(self, 'files'), 'image_dir or files must be provided'
self.label_fn = label_fn or self.get_label_fn(mask_dir)
self.check_label_fn()
# Init zarr storage
self.store = str(zarr_store) if zarr_store else zarr.storage.TempStore()
self.root = zarr.group(store=self.store)
self.g_smx = self.root.require_group('smx')
self.g_eng, self.g_std = None, None
def get_images(self, img_dir:str='images', img_path:Path=None) -> List[Path]:
    'Returns list of image paths found in `img_path` or `self.path/img_dir` (non-recursive).'
    path = img_path or self.path/img_dir
    files = get_image_files(path, recurse=False)
    print(f'Found {len(files)} images in "{path}".')
    # Warn instead of raising: the caller may still attach files later.
    if len(files)==0: warnings.warn('Please check your provided images and image folder')
    return files
def init_models(self):
self.models = []
self.stats = None
for p in self.models_paths:
model, stats = load_smp_model(p)
if not self.stats: self.stats = stats
assert np.array_equal(stats, self.stats), 'Only models trained on the same stats are allowed.'
model.float()
model.eval()
model.to(self.device)
self.models.append(model)
def get_label_fn(self, msk_dir:str='masks', msk_path:Path=None):
    'Returns label function to get paths of masks'
    path = msk_path or self.path/msk_dir
    # Delegates to the module-level utils.get_label_fn (which this method shadows),
    # inferring the naming pattern from the first image file.
    return get_label_fn(self.files[0], path)
def predict(self,
ds,
use_tta=True,
bs=4,
use_gaussian=True,
sigma_scale=1./8,
uncertainty_estimates=True,
uncertainty_type = 'uncertainty',
energy_scores=False,
energy_T = 1.,
verbose=0):
def check_label_fn(self):
    'Checks that `self.label_fn` yields an existing mask file for every image in `self.files`.'
    mask_check = [self.label_fn(x).exists() for x in self.files]
    chk_str = f'Found {sum(mask_check)} corresponding masks.'
    print(chk_str)
    # Non-fatal: warn on missing masks so partial datasets can still be inspected.
    if len(self.files)!=sum(mask_check):
        warnings.warn(f'Please check your images and masks (and folders).')
if verbose>0: print('Ensemble prediction with models:', self.models_paths)
def predict(self, arr:Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    'Get prediction for `arr` using `self.inference_ensemble`; returns the outputs as numpy arrays.'
    inp = torch.tensor(arr).float().to(self.device)
    # inference_mode disables autograd bookkeeping for faster prediction.
    with torch.inference_mode():
        preds = self.inference_ensemble(inp)
    # Move every output tensor back to the CPU as a numpy array.
    preds = [x.cpu().numpy() for x in preds]
    return tuple(preds)
tfms = [tta.HorizontalFlip(),tta.VerticalFlip()] if use_tta else []
if verbose>0: print('Using Test-Time Augmentation with:', tfms)
def save_preds_zarr(self, f_name, pred, smx, std):
    'Store prediction, softmax, and std arrays under key `f_name` in the zarr groups.'
    self.g_pred[f_name] = pred
    self.g_smx[f_name] = smx
    self.g_std[f_name] = std
dl = DataLoader(ds, bs, num_workers=0, shuffle=False, pin_memory=True)
def _create_ds(self, **kwargs):
    'Create a `BaseDataset` from `self.files` and store it on `self.ds`; extra kwargs are forwarded.'
    self.ds = BaseDataset(self.files, label_fn=self.label_fn, instance_labels=self.instance_labels,
                          num_classes=self.num_classes, **kwargs)
# Create zero arrays
data_shape = ds.image_shapes[0]
softmax = np.zeros((*data_shape, ds.c), dtype='float32')
merge_map = np.zeros(data_shape, dtype='float32')
stdeviation = np.zeros(data_shape, dtype='float32') if uncertainty_estimates else None
energy = np.zeros(data_shape, dtype='float32') if energy_scores else None
# Define merge weights
if use_gaussian:
mw_numpy = _get_gaussian(ds.output_shape, sigma_scale)
else:
mw_numpy = np.ones(dl.output_shape)
mw = torch.from_numpy(mw_numpy).to(self.device)
# Loop over tiles (indices required!)
for tiles, idxs in iter(dl):
tiles = tiles.to(self.device)
smx_merger = tta.Merger()
if energy_scores:
energy_merger = tta.Merger()
# Loop over tt-augmentations
for t in tta.Compose(tfms):
aug_tiles = t.augment_image(tiles)
model_merger = tta.Merger()
if energy_scores: engergy_list = []
# Loop over models
for model in self.models:
with torch.inference_mode():
logits = model(aug_tiles)
logits = t.deaugment_mask(logits)
smx_merger.append(F.softmax(logits, dim=1))
if energy_scores:
energy_merger.append(-energy_score(logits, energy_T)) #negative energy score
out_list = []
# Apply gaussian weigthing
batch_smx = smx_merger.result()*mw.view(1,1,*mw.shape)
# Reshape and append to list
out_list.append([x for x in batch_smx.permute(0,2,3,1).cpu().numpy()])
if uncertainty_estimates:
batch_std = torch.mean(smx_merger.result(uncertainty_type), dim=1)*mw.view(1,*mw.shape)
out_list.append([x for x in batch_std.cpu().numpy()])
if energy_scores:
batch_energy = energy_merger.result()*mw.view(1,*mw.shape)
out_list.append([x for x in batch_energy.cpu().numpy()])
# Compose predictions
for preds in zip(*out_list, idxs):
if len(preds)==4: smx,std,eng,idx = preds
elif uncertainty_estimates: smx,std,idx = preds
elif energy_scores: smx,eng,idx = preds
else: smx, idx = preds
out_slice = ds.out_slices[idx]
in_slice = ds.in_slices[idx]
softmax[out_slice] += smx[in_slice]
merge_map[out_slice] += mw_numpy[in_slice]
if uncertainty_estimates:
stdeviation[out_slice] += std[in_slice]
if energy_scores:
energy[out_slice] += eng[in_slice]
# Normalize weighting
softmax /= merge_map[..., np.newaxis]
if uncertainty_estimates:
stdeviation /= merge_map
if energy_scores:
energy /= merge_map
return softmax, stdeviation, energy
def predict_images(self, image_list, ds_kwargs={}, verbose=1, **kwargs):
"Predict images in 'image_list' with kwargs and save to zarr"
for f in progress_bar(image_list, leave=False):
if verbose>0: print(f'Predicting {f.name}')
ds = TileDataset([f], stats=self.stats, return_index=True, **ds_kwargs)
softmax, stdeviation, energy = self.predict(ds, **kwargs)
# Save to zarr
self.g_smx[f.name] = softmax
if stdeviation is not None:
self.g_std = self.root.require_group('std')
self.g_std[f.name] = stdeviation
if energy is not None:
self.g_eng = self.root.require_group('energy')
self.g_eng[f.name] = energy
return self.g_smx, self.g_std, self.g_eng
# Cell
class EnsembleLearner(GetAttr):
_default = 'config'
def __init__(self, image_dir='images', mask_dir=None, config=None, path=None, ensemble_path=None, preproc_dir=None,
label_fn=None, metrics=None, cbs=None, ds_kwargs={}, dl_kwargs={}, model_kwargs={}, stats=None, files=None):
class EnsembleLearner(EnsembleBase):
"Meta class to training model ensembles with `n` models"
def __init__(self, *args, ensemble_path=None, preproc_dir=None, metrics=None, cbs=None,
ds_kwargs={}, dl_kwargs={}, model_kwargs={}, stats=None, **kwargs):
super().__init__(*args, **kwargs)
self.config = config or Config()
assert hasattr(self, 'label_fn'), 'mask_dir or label_fn must be provided.'
self.stats = stats

@@ -337,38 +133,15 @@ self.dl_kwargs = dl_kwargs

self.add_ds_kwargs = ds_kwargs
self.path = Path(path) if path is not None else Path('.')
default_metrics = [Dice()] if self.n_classes==2 else [DiceMulti()]
default_metrics = [Dice()] if self.num_classes==2 else [DiceMulti()]
self.metrics = metrics or default_metrics
self.loss_fn = self.get_loss()
self.cbs = cbs or [SaveModelCallback(monitor='dice' if self.n_classes==2 else 'dice_multi')] #ShowGraphCallback
self.cbs = cbs or [SaveModelCallback(monitor='dice' if self.num_classes==2 else 'dice_multi')] #ShowGraphCallback
self.ensemble_dir = ensemble_path or self.path/self.ens_dir
if ensemble_path is not None:
ensemble_path.mkdir(exist_ok=True, parents=True)
self.load_ensemble(path=ensemble_path)
self.load_models(path=ensemble_path)
else: self.models = {}
self.files = L(files) or get_image_files(self.path/image_dir, recurse=False)
assert len(self.files)>0, f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder'
if any([mask_dir, label_fn]):
if label_fn: self.label_fn = label_fn
else: self.label_fn = get_label_fn(self.files[0], self.path/mask_dir)
#Check if corresponding masks exist
mask_check = [self.label_fn(x).exists() for x in self.files]
chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
assert len(self.files)==sum(mask_check) and len(self.files)>0, f'Please check your images and masks (and folders). {chk_str}'
print(chk_str)
else:
self.label_fn = label_fn
self.n_splits=min(len(self.files), self.max_splits)
self._set_splits()
self.ds = RandomTileDataset(self.files, label_fn=self.label_fn,
preproc_dir=preproc_dir,
instance_labels=self.instance_labels,
n_classes=self.n_classes,
stats=self.stats,
normalize = True,
sample_mult=self.sample_mult if self.sample_mult>0 else None,
verbose=0,
**self.add_ds_kwargs)
self._create_ds(stats=self.stats, preproc_dir=preproc_dir, verbose=1, **self.add_ds_kwargs)
self.stats = self.ds.stats

@@ -397,4 +170,5 @@ self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]

ds_kwargs['tile_shape']= (self.tile_shape,)*2
ds_kwargs['n_classes']= self.n_classes
ds_kwargs['shift']= self.shift
ds_kwargs['num_classes']= self.num_classes
ds_kwargs['max_tile_shift']= self.max_tile_shift
ds_kwargs['scale']= self.scale
ds_kwargs['border_padding_factor']= self.border_padding_factor

@@ -413,6 +187,8 @@ return ds_kwargs

ds_kwargs['tile_shape']= (self.tile_shape,)*2
ds_kwargs['n_classes']= self.n_classes
ds_kwargs['shift']= 1.
ds_kwargs['num_classes']= self.num_classes
ds_kwargs['scale']= self.scale
ds_kwargs['flip'] = self.flip
ds_kwargs['max_tile_shift']= 1.
ds_kwargs['border_padding_factor']= 0.
ds_kwargs['flip'] = self.flip
ds_kwargs['scale']= self.scale
ds_kwargs['albumentations_tfms'] = self._compose_albumentations(**self.albumentation_kwargs)

@@ -424,7 +200,8 @@ ds_kwargs['sample_mult'] = self.sample_mult if self.sample_mult>0 else None

def model_name(self):
    'Canonical model name built from arch, encoder, and class count.'
    # Underscores inside the encoder name are replaced so that `_` stays a
    # reliable field separator when the name is parsed back (see load_models).
    encoder_name = self.encoder_name.replace('_', '-')
    return f'{self.arch}_{encoder_name}_{self.num_classes}classes'
def get_loss(self):
kwargs = {'mode':self.mode,
'classes':[x for x in range(1, self.n_classes)],
'classes':[x for x in range(1, self.num_classes)],
'smooth_factor': self.loss_smooth_factor,

@@ -439,9 +216,8 @@ 'alpha':self.loss_alpha,

ds = []
ds.append(RandomTileDataset(files, label_fn=self.label_fn, **self.train_ds_kwargs))
ds.append(RandomTileDataset(files, label_fn=self.label_fn, **self.train_ds_kwargs, verbose=0))
if files_val:
ds.append(TileDataset(files_val, label_fn=self.label_fn, **self.train_ds_kwargs))
ds.append(TileDataset(files_val, label_fn=self.label_fn, **self.train_ds_kwargs, verbose=0))
else:
ds.append(ds[0])
dls = DataLoaders.from_dsets(*ds, bs=self.batch_size, pin_memory=True, **self.dl_kwargs)
if torch.cuda.is_available(): dls.cuda()
dls = DataLoaders.from_dsets(*ds, bs=self.batch_size, pin_memory=True, **self.dl_kwargs).to(self.device)
return dls

@@ -454,11 +230,12 @@

in_channels=self.in_channels,
classes=self.n_classes,
**self.model_kwargs)
if torch.cuda.is_available(): model.cuda()
classes=self.num_classes,
**self.model_kwargs).to(self.device)
return model
def fit(self, i, n_iter=None, base_lr=None, **kwargs):
n_iter = n_iter or self.n_iter
def fit(self, i, n_epochs=None, base_lr=None, **kwargs):
'Fit model number `i`'
n_epochs = n_epochs or self.n_epochs
base_lr = base_lr or self.base_lr
name = self.ensemble_dir/f'{self.model_name}-fold{i}.pth'
name = self.ensemble_dir/'single_models'/f'{self.model_name}-fold{i}.pth'
model = self._create_model()

@@ -480,5 +257,3 @@ files_train, files_val = self.splits[i]

print(f'Starting training for {name.name}')
epochs = calc_iterations(n_iter=n_iter,ds_length=len(dls.train_ds), bs=self.batch_size)
#self.learn.fit_one_cycle(epochs, lr_max)
self.learn.fine_tune(epochs, base_lr=base_lr)
self.learn.fine_tune(n_epochs, base_lr=base_lr)

@@ -491,8 +266,34 @@ print(f'Saving model at {name}')

def fit_ensemble(self, n_iter, skip=False, **kwargs):
del model
if torch.cuda.is_available(): torch.cuda.empty_cache()
def get_inference_ensemble(self, model_path=None):
model_paths = [model_path] if model_path is not None else self.models.values()
models = [load_smp_model(p)[0] for p in model_paths]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ensemble = InferenceEnsemble(models,
num_classes=self.num_classes,
in_channels=self.in_channels,
channel_means=self.stats['channel_means'].tolist(),
channel_stds=self.stats['channel_stds'].tolist(),
tile_shape=(self.tile_shape,)*2,
**self.inference_kwargs).to(self.device)
return torch.jit.script(ensemble)
def save_inference_ensemble(self):
ensemble = self.get_inference_ensemble()
ensemble_name = self.ensemble_dir/f'ensemble_{self.model_name}.pt'
print(f'Saving model at {ensemble_name}')
ensemble.save(ensemble_name)
def fit_ensemble(self, n_epochs=None, skip=False, save_inference_ensemble=True, **kwargs):
'Fit `i` models and `skip` existing'
for i in range(1, self.n_models+1):
if skip and (i in self.models): continue
self.fit(i, n_iter, **kwargs)
self.fit(i, n_epochs, **kwargs)
if save_inference_ensemble: self.save_inference_ensemble()
def set_n(self, n):
"Change to `n` models per ensemble"
for i in range(n, len(self.models)):

@@ -503,4 +304,7 @@ self.models.pop(i+1, None)

def get_valid_results(self, model_no=None, zarr_store=None, export_dir=None, filetype='.png', **kwargs):
"Validate models on validation data and save results"
res_list = []
model_list = self.models if not model_no else {k:v for k,v in self.models.items() if k==model_no}
model_dict = self.models if not model_no else {k:v for k,v in self.models.items() if k==model_no}
metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score'
if export_dir:

@@ -513,33 +317,31 @@ export_dir = Path(export_dir)

for i, model_path in model_list.items():
ep = EnsemblePredict(models_paths=[model_path], zarr_store=zarr_store)
for i, model_path in model_dict.items():
print(f'Validating model {i}.')
self.inference_ensemble = self.get_inference_ensemble(model_path=model_path)
_, files_val = self.splits[i]
g_smx, g_std, g_eng = ep.predict_images(files_val, bs=self.batch_size, ds_kwargs=self.pred_ds_kwargs, **kwargs)
del ep
torch.cuda.empty_cache()
chunk_store = g_smx.chunk_store.path
for j, f in enumerate(files_val):
msk = self.ds.get_data(f, mask=True)[0]
pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
m_dice = dice_score(msk, pred)
m_path = self.models[i].name
for j, f in progress_bar(enumerate(files_val), total=len(files_val)):
pred, smx, std = self.predict(self.ds.data[f.name][:])
self.save_preds_zarr(f.name, pred, smx, std)
msk = self.ds.labels[f.name][:] #.get_data(f, mask=True)[0])
m_dice = dice_score(msk, pred, num_classes=self.num_classes)
df_tmp = pd.Series({'file' : f.name,
'model' : m_path,
'model' : model_path,
'model_no' : i,
'dice_score': m_dice,
#'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
'uncertainty_score': np.mean(g_std[f.name][:][pred>0]) if g_std is not None else None,
metric_name: m_dice,
'uncertainty_score': np.mean(std[pred>0]),
'image_path': f,
'mask_path': self.label_fn(f),
'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
'engergy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None,
'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None})
'pred_path': f'{self.store}/{self.g_pred.path}/{f.name}',
'softmax_path': f'{self.store}/{self.g_smx.path}/{f.name}',
'uncertainty_path': f'{self.store}/{self.g_std.path}/{f.name}'})
res_list.append(df_tmp)
if export_dir:
save_mask(pred, pred_path/f'{df_tmp.file}_{df_tmp.model}_mask', filetype)
if g_std is not None:
save_unc(g_std[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.model}_uncertainty', filetype)
if g_eng is not None:
save_unc(g_eng[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.model}_energy', filetype)
save_mask(pred, pred_path/f'{df_tmp.file}_model{df_tmp.model_no}_mask', filetype)
save_unc(std, unc_path/f'{df_tmp.file}_model{df_tmp.model_no}_uncertainty', filetype)
del self.inference_ensemble
if torch.cuda.is_available(): torch.cuda.empty_cache()
self.df_val = pd.DataFrame(res_list)

@@ -551,3 +353,4 @@ if export_dir:

def show_valid_results(self, model_no=None, files=None, **kwargs):
def show_valid_results(self, model_no=None, files=None, metric_name='auto', **kwargs):
"Plot results of all or `file` validation images",
if self.df_val is None: self.get_valid_results(**kwargs)

@@ -557,13 +360,14 @@ df = self.df_val

if model_no is not None: df = df[df.model_no==model_no]
if metric_name=='auto': metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score'
for _, r in df.iterrows():
img = self.ds.get_data(r.image_path)[0][:]
msk = self.ds.get_data(r.image_path, mask=True)[0]
pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
std = zarr.load(r.uncertainty_path)
img = self.ds.data[r.file][:]
msk = self.ds.labels[r.file][:]
pred = self.g_pred[r.file][:]
std = self.g_std[r.file][:]
_d_model = f'Model {r.model_no}'
if self.tta: plot_results(img, msk, pred, std, df=r, model=_d_model)
else: plot_results(img, msk, pred, np.zeros_like(pred), df=r, model=_d_model)
plot_results(img, msk, pred, std, df=r, num_classes=self.num_classes, metric_name=metric_name, model=_d_model)
def load_models(self, path=None):
    "Get models saved at `path` (defaults to `self.ensemble_dir/'single_models'`)."
    path = path or self.ensemble_dir/'single_models'
    models = sorted(get_files(path, extensions='.pth', recurse=False))
    self.models = {}
    for i, m in enumerate(models, 1):
        # Bugfix: enumerate starts at 1, so the first model is i==1; the old
        # `i==0` test was dead code and num_classes was never read from disk.
        # The class count is the leading digit of the third `_`-separated
        # name field (see model_name).
        if i == 1: self.num_classes = int(m.name.split('_')[2][0])
        else: assert self.num_classes==int(m.name.split('_')[2][0]), 'Check models. Models are trained on different number of classes.'
        self.models[i] = m

@@ -587,9 +391,53 @@

def get_ensemble_results(self, files, zarr_store=None, export_dir=None, filetype='.png', **kwargs):
ep = EnsemblePredict(models_paths=self.models.values(), zarr_store=zarr_store)
g_smx, g_std, g_eng = ep.predict_images(files, bs=self.batch_size, ds_kwargs=self.pred_ds_kwargs, **kwargs)
chunk_store = g_smx.chunk_store.path
del ep
torch.cuda.empty_cache()
def lr_find(self, files=None, **kwargs):
    "Wrapper function for the fastai learning rate finder; returns (suggested_lrs, recorder)."
    files = files or self.files
    dls = self._get_dls(files)
    model = self._create_model()
    learn = Learner(dls, model, metrics=self.metrics, wd=self.weight_decay,
                    loss_func=self.loss_fn, opt_func=_optim_dict[self.optim])
    # Run the sweep in half precision when enabled in the config.
    if self.mixed_precision_training: learn.to_fp16()
    sug_lrs = learn.lr_find(**kwargs)
    return sug_lrs, learn.recorder
# Cell
class EnsemblePredictor(EnsembleBase):
def __init__(self, *args, ensemble_path:Path=None, **kwargs):
if ensemble_path is not None:
self.load_inference_ensemble(ensemble_path)
super().__init__(*args, **kwargs)
if hasattr(self, 'inference_ensemble'):
self.config.num_classes = self.inference_ensemble.num_classes
if hasattr(self, 'files'):
self._create_ds(stats={}, use_zarr_data = False, verbose=1)
self.ensemble_dir = self.path/self.ens_dir
#if ensemble_path is not None:
# self.load_inference_ensemble(ensemble_path)
def load_inference_ensemble(self, ensemble_path:Path=None):
    "Load inference_ensemble from `self.ensemble_dir` or from `ensemble_path`"
    path = ensemble_path or self.ensemble_dir
    if path.is_dir():
        # Pick the first exported TorchScript ensemble (.pt) in the directory.
        path_list = get_files(path, extensions='.pt', recurse=False)
        if len(path_list)==0:
            warnings.warn(f'No inference ensemble available at {path}. Did you train your ensemble correctly?')
            return
        path = path_list[0]
    self.inference_ensemble_name = path.name
    # Move to the configured device when one has already been set up.
    if hasattr(self, 'device'): self.inference_ensemble = torch.jit.load(path).to(self.device)
    else: self.inference_ensemble = torch.jit.load(path)
    print(f'Successfully loaded InferenceEnsemble from {path}')
def get_ensemble_results(self, file_list=None, export_dir=None, filetype='.png', **kwargs):
'Predict files in file_list using InferenceEnsemble'
if file_list is not None:
self.files = file_list
self._create_ds(stats={}, use_zarr_data = False, verbose=1)
if export_dir:

@@ -603,68 +451,78 @@ export_dir = Path(export_dir)

res_list = []
for f in files:
pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
for f in progress_bar(self.files):
img = self.ds.read_img(f)
pred, smx, std = self.predict(img)
self.save_preds_zarr(f.name, pred, smx, std)
df_tmp = pd.Series({'file' : f.name,
'ensemble' : self.model_name,
'n_models' : len(self.models),
#'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
'uncertainty_score': np.mean(g_std[f.name][:][pred>0]) if g_std is not None else None,
'ensemble' : self.inference_ensemble_name,
'uncertainty_score': np.mean(std[pred>0]),
'image_path': f,
'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None,
'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None})
'pred_path': f'{self.store}/{self.g_pred.path}/{f.name}',
'softmax_path': f'{self.store}/{self.g_smx.path}/{f.name}',
'uncertainty_path': f'{self.store}/{self.g_std.path}/{f.name}'})
res_list.append(df_tmp)
if export_dir:
save_mask(pred, pred_path/f'{df_tmp.file}_{df_tmp.ensemble}_mask', filetype)
if g_std is not None:
save_unc(g_std[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.ensemble}_unc', filetype)
if g_eng is not None:
save_unc(g_eng[f.name][:], unc_path/f'{df_tmp.file}_{df_tmp.ensemble}_energy', filetype)
save_mask(pred, pred_path/f'{df_tmp.file}_mask', filetype)
save_unc(std, unc_path/f'{df_tmp.file}_unc', filetype)
self.df_ens = pd.DataFrame(res_list)
return g_smx, g_std, g_eng
return self.g_pred, self.g_smx, self.g_std
def score_ensemble_results(self, mask_dir=None, label_fn=None):
if mask_dir is not None and label_fn is None:
label_fn = get_label_fn(self.df_ens.image_path[0], self.path/mask_dir)
for i, r in self.df_ens.iterrows():
if label_fn is not None:
msk_path = self.label_fn(r.image_path)
msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
self.df_ens.loc[i, 'mask_path'] = msk_path
"Compare ensemble results to given segmentation masks."
if any(v is not None for v in (mask_dir, label_fn)):
self.label_fn = label_fn or self.get_label_fn(mask_dir)
self._create_ds(stats={}, use_zarr_data = False, verbose=1)
print('Calculating metrics')
for i, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)):
msk = self.ds.labels[r.file][:]
pred = self.g_pred[r.file][:]
if self.num_classes==2:
self.df_ens.loc[i, f'dice_score'] = binary_dice_score(msk, pred)
else:
msk = self.ds.labels[r.file][:]
pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
self.df_ens.loc[i, 'dice_score'] = dice_score(msk, pred)
for cl in range(self.num_classes):
msk_bin = msk==cl
pred_bin = pred==cl
if np.any([msk_bin, pred_bin]):
self.df_ens.loc[i, f'dice_score_class{cl}'] = binary_dice_score(msk_bin, pred_bin)
if self.num_classes>2:
self.df_ens['average_dice_score'] = self.df_ens[[col for col in self.df_ens if col.startswith('dice_score_class')]].mean(axis=1)
return self.df_ens
def show_ensemble_results(self, files=None, unc=True, unc_metric=None, metric_name='dice_score'):
def show_ensemble_results(self, files=None, unc=True, unc_metric=None, metric_name='auto'):
"Show result of ensemble or `model_no`"
assert self.df_ens is not None, "Please run `get_ensemble_results` first."
df = self.df_ens
if files is not None: df = df.reset_index().set_index('file', drop=False).loc[files]
if metric_name=='auto': metric_name = 'dice_score' if self.num_classes==2 else 'average_dice_score'
for _, r in df.iterrows():
imgs = []
imgs.append(_read_img(r.image_path)[:])
imgs.append(self.ds.read_img(r.image_path))
if metric_name in r.index:
try: msk = self.ds.labels[r.file][:]
except: msk = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
imgs.append(msk)
imgs.append(self.ds.labels[r.file][:])
hastarget=True
else:
hastarget=False
imgs.append(np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8'))
if unc: imgs.append(zarr.load(r.uncertainty_path))
plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric)
imgs.append(self.g_pred[r.file])
if unc: imgs.append(self.g_std[r.file])
plot_results(*imgs, df=r, hastarget=hastarget, num_classes=self.num_classes, metric_name=metric_name, unc_metric=unc_metric)
def get_cellpose_results(self, export_dir=None):
'Get instance segmentation results using the cellpose integration'
assert self.df_ens is not None, "Please run `get_ensemble_results` first."
cl = self.cellpose_export_class
assert cl<self.n_classes, f'{cl} not avaialable from {self.n_classes} classes'
assert cl<self.num_classes, f'{cl} not avaialable from {self.num_classes} classes'
smxs, preds = [], []
for _, r in self.df_ens.iterrows():
softmax = zarr.load(r.softmax_path)
smxs.append(softmax)
preds.append(np.argmax(softmax, axis=-1).astype('uint8'))
smxs.append(self.g_smx[r.file][:])
preds.append(self.g_pred[r.file][:])
probs = [x[...,cl] for x in smxs]
probs = [x[cl] for x in smxs]
masks = [x==cl for x in preds]

@@ -678,7 +536,6 @@ cp_masks = run_cellpose(probs, masks,

if export_dir:
export_dir = Path(export_dir)
cp_path = export_dir/'cellpose_masks'
cp_path.mkdir(parents=True, exist_ok=True)
export_dir = Path(export_dir)/'instance_labels'
export_dir.mkdir(parents=True, exist_ok=True)
for idx, r in self.df_ens.iterrows():
tifffile.imwrite(cp_path/f'{r.file}_class{cl}.tif', cp_masks[idx], compress=6)
tifffile.imwrite(export_dir/f'{r.file}_class{cl}.tif', cp_masks[idx], compress=6)

@@ -689,32 +546,31 @@ self.cellpose_masks = cp_masks

def score_cellpose_results(self, mask_dir=None, label_fn=None):
"Compare cellpose nstance segmentation results to given masks."
assert self.cellpose_masks is not None, 'Run get_cellpose_results() first'
if mask_dir is not None and label_fn is None:
label_fn = get_label_fn(self.df_ens.image_path[0], self.path/mask_dir)
if any(v is not None for v in (mask_dir, label_fn)):
self.label_fn = label_fn or self.get_label_fn(mask_dir)
self._create_ds(stats={}, use_zarr_data = False, verbose=1)
cl = self.cellpose_export_class
for i, r in self.df_ens.iterrows():
if label_fn is not None:
msk_path = self.label_fn(r.image_path)
msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
self.df_ens.loc[i, 'mask_path'] = msk_path
else:
msk = self.ds.labels[r.file][:]
_, msk = cv2.connectedComponents(msk, connectivity=4)
msk = self.ds.labels[r.file][:]==cl
_, msk = cv2.connectedComponents(msk.astype('uint8'), connectivity=4)
pred = self.cellpose_masks[i]
ap, tp, fp, fn = get_instance_segmentation_metrics(msk, pred, is_binary=False, min_pixel=self.min_pixel_export)
self.df_ens.loc[i, 'mean_average_precision'] = ap.mean()
self.df_ens.loc[i, 'average_precision_at_iou_50'] = ap[0]
self.df_ens.loc[i, f'mAP_class{cl}'] = ap.mean()
self.df_ens.loc[i, f'mAP_iou50_class{cl}'] = ap[0]
return self.df_ens
def show_cellpose_results(self, files=None, unc=True, unc_metric=None, metric_name='mean_average_precision'):
def show_cellpose_results(self, files=None, unc_metric=None, metric_name='auto'):
'Show instance segmentation results from cellpose predictions.'
assert self.df_ens is not None, "Please run `get_ensemble_results` first."
df = self.df_ens.reset_index()
if files is not None: df = df.set_index('file', drop=False).loc[files]
if metric_name=='auto': metric_name=f'mAP_class{self.cellpose_export_class}'
for _, r in df.iterrows():
imgs = []
imgs.append(_read_img(r.image_path)[:])
imgs = [self.ds.read_img(r.image_path)]
if metric_name in r.index:
try:
mask = self.ds.labels[idx][:]
except:
mask = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
_, comps = cv2.connectedComponents((mask==self.cellpose_export_class).astype('uint8'), connectivity=4)
mask = self.ds.labels[r.file][:]
mask = (mask==self.cellpose_export_class).astype('uint8')
_, comps = cv2.connectedComponents(mask, connectivity=4)
imgs.append(label2rgb(comps, bg_label=0))

@@ -724,16 +580,9 @@ hastarget=True

hastarget=False
imgs.append(label2rgb(self.cellpose_masks[r['index']], bg_label=0))
if unc: imgs.append(zarr.load(r.uncertainty_path))
plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric)
imgs.append(self.g_std[r.file])
plot_results(*imgs, df=r, hastarget=hastarget, num_classes=self.num_classes, instance_labels=True, metric_name=metric_name, unc_metric=unc_metric)
def lr_find(self, files=None, **kwargs):
    """Run fastai's learning-rate finder on a freshly built model.

    :param files: files to build the dataloaders from (defaults to `self.files`)
    :param kwargs: forwarded to `Learner.lr_find`
    :return: (suggested learning rates, the learner's `Recorder`)
    """
    files = files or self.files
    dls = self._get_dls(files)
    model = self._create_model()
    # Mirror the training setup (loss, optimizer, weight decay) so the
    # suggested LR is representative of an actual training run.
    learn = Learner(dls, model, metrics=self.metrics, wd=self.weight_decay, loss_func=self.loss_fn, opt_func=_optim_dict[self.optim])
    if self.mixed_precision_training: learn.to_fp16()
    sug_lrs = learn.lr_find(**kwargs)
    return sug_lrs, learn.recorder
def export_imagej_rois(self, output_folder='ROI_sets', **kwargs):
'Export ImageJ ROI Sets to `ouput_folder`'
assert self.df_ens is not None, "Please run prediction first."

@@ -744,40 +593,13 @@

for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)):
mask = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
uncertainty = zarr.load(r.uncertainty_path)
export_roi_set(mask, uncertainty, name=r.file, path=output_folder, ascending=False, **kwargs)
pred = self.g_pred[r.file][:]
uncertainty = self.g_std[r.file][:]
export_roi_set(pred, uncertainty, name=r.file, path=output_folder, ascending=False, **kwargs)
def export_cellpose_rois(self, output_folder='cellpose_ROI_sets', **kwargs):
'Export cellpose predictions to ImageJ ROI Sets in `ouput_folder`'
output_folder = Path(output_folder)
output_folder.mkdir(exist_ok=True, parents=True)
for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)):
mask = self.cellpose_masks[idx]
uncertainty = zarr.load(r.uncertainty_path)
export_roi_set(mask, uncertainty, instance_labels=True, name=r.file, path=output_folder, ascending=False, **kwargs)
#def clear_tmp(self):
# try:
# shutil.rmtree('/tmp/*', ignore_errors=True)
# shutil.rmtree(self.path/'.tmp')
# print(f'Deleted temporary files from {self.path/".tmp"}')
# except: print(f'No temporary files to delete at {self.path/".tmp"}')
# Cell
add_docs(EnsembleLearner, "Meta class to train and predict model ensembles with `n` models",
fit="Fit model number `i`",
fit_ensemble="Fit `i` models and `skip` existing",
get_valid_results="Validate models on validation data and save results",
show_valid_results="Plot results of all or `file` validation images",
get_ensemble_results="Get models and ensemble results",
score_ensemble_results="Compare ensemble results to given segmentation masks.",
show_ensemble_results="Show result of ensemble or `model_no`",
load_ensemble="Get models saved at `path`",
get_cellpose_results='Get instance segmentation results using the cellpose integration',
score_cellpose_results="Compare cellpose nstance segmentation results to given masks.",
show_cellpose_results='Show instance segmentation results from cellpose predictions.',
get_loss="Get loss function from loss name (config)",
set_n="Change to `n` models per ensemble",
lr_find="Wrapper for learning rate finder",
export_imagej_rois='Export ImageJ ROI Sets to `ouput_folder`',
export_cellpose_rois='Export cellpose predictions to ImageJ ROI Sets in `ouput_folder`',
#clear_tmp="Clear directory with temporary files"
)
pred = self.cellpose_masks[idx]
uncertainty = self.g_std[r.file][:]
export_roi_set(pred, uncertainty, instance_labels=True, name=r.file, path=output_folder, ascending=False, **kwargs)
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_losses.ipynb (unless otherwise specified).
__all__ = ['LOSSES', 'FastaiLoss', 'WeightedLoss', 'JointLoss', 'get_loss']
__all__ = ['LOSSES', 'FastaiLoss', 'WeightedLoss', 'JointLoss', 'Poly1CrossEntropyLoss', 'get_loss']

@@ -16,3 +16,3 @@ # Cell

# Cell
LOSSES = ['CrossEntropyLoss', 'DiceLoss', 'SoftCrossEntropyLoss', 'CrossEntropyDiceLoss', 'JaccardLoss', 'FocalLoss', 'LovaszLoss', 'TverskyLoss']
LOSSES = ['CrossEntropyLoss', 'DiceLoss', 'SoftCrossEntropyLoss', 'CrossEntropyDiceLoss', 'JaccardLoss', 'FocalLoss', 'LovaszLoss', 'TverskyLoss', 'Poly1CrossEntropyLoss']

@@ -63,2 +63,39 @@ # Cell

# Cell
class Poly1CrossEntropyLoss(nn.Module):
    """Poly-1 cross-entropy loss (PolyLoss).

    Adds an epsilon-weighted correction term ``epsilon * (1 - pt)`` to the
    standard cross-entropy, where ``pt`` is the softmax probability assigned
    to the true class of each pixel.

    :param num_classes: number of classes (size of the logit channel dim)
    :param epsilon: weight of the Poly-1 correction term
    :param reduction: one of none|sum|mean, applied to the final loss tensor
    """

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "mean"):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, labels):
        """Compute the Poly-1 cross-entropy loss.

        :param logits: tensor of shape [B, N, H, W] (raw, unnormalized scores)
        :param labels: tensor of shape [B, H, W] (integer class indices)
        :return: scalar for 'mean'/'sum', per-pixel tensor for 'none'
        """
        # One-hot encode the targets and move the class axis next to batch
        # so it lines up with the logit layout [B, N, H, W].
        labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(device=logits.device,
                                                                           dtype=logits.dtype)
        labels_onehot = torch.moveaxis(labels_onehot, -1, 1)
        # pt = predicted probability of the true class, per pixel
        pt = torch.sum(labels_onehot * F.softmax(logits, dim=1), dim=1)
        CE = F.cross_entropy(input=logits, target=labels, reduction='none')
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
# Cell
def get_loss(loss_name, mode='multiclass', classes=[1], smooth_factor=0., alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', **kwargs):

@@ -96,2 +133,5 @@ 'Load losses from based on loss_name'

elif loss_name=="Poly1CrossEntropyLoss":
loss = Poly1CrossEntropyLoss(num_classes=max(classes)+1)
return loss
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_models.ipynb (unless otherwise specified).
__all__ = ['ARCHITECTURES', 'ENCODERS', 'get_pretrained_options', 'create_smp_model', 'save_smp_model',
'load_smp_model', 'check_cellpose_installation', 'get_diameters', 'run_cellpose']
__all__ = ['ARCHITECTURES', 'ENCODERS', 'get_pretrained_options', 'PATCH_UNET_DECODER', 'create_smp_model',
'save_smp_model', 'load_smp_model', 'check_cellpose_installation', 'get_diameters', 'run_cellpose']

@@ -31,2 +31,24 @@ # Cell

# Cell
# Module-level toggle read inside the patched decoder forward below;
# set by `create_smp_model` (True only for Unet with a 4-stage encoder).
PATCH_UNET_DECODER = False

@patch
def forward(self:smp.decoders.unet.decoder.UnetDecoder, *features):
    """Patched `UnetDecoder.forward`: optionally adds a final 2x bilinear upsample.

    Behaves exactly like the stock smp decoder when `PATCH_UNET_DECODER` is False.
    """
    features = features[1:]  # remove first skip with same spatial resolution
    features = features[::-1]  # reverse channels to start from head of encoder
    head = features[0]
    skips = features[1:]
    x = self.center(head)
    for i, decoder_block in enumerate(self.blocks):
        # Deeper decoder blocks may have no matching skip connection
        skip = skips[i] if i < len(skips) else None
        x = decoder_block(x, skip)
    if PATCH_UNET_DECODER:
        # NOTE(review): presumably compensates for encoder_depth=4 (ConvNeXt
        # encoders) so the output matches the input resolution — confirm.
        x = torch.nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
    return x
# Cell
def create_smp_model(arch, **kwargs):

@@ -36,4 +58,15 @@ 'Create segmentation_models_pytorch model'

assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'
is_convnext_encoder = kwargs['encoder_name'].startswith('tu-convnext')
assert not all((arch!="Unet", is_convnext_encoder)), 'ConvNeXt encoder can only be used with Unet'
if arch=="Unet": model = smp.Unet(**kwargs)
if arch=="Unet":
global PATCH_UNET_DECODER
if is_convnext_encoder:
kwargs['encoder_depth'] = 4
kwargs['decoder_channels'] = (256, 128, 64, 16)
PATCH_UNET_DECODER = True
else:
PATCH_UNET_DECODER = False
model = smp.Unet(**kwargs)
elif arch=="UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs)

@@ -40,0 +73,0 @@ elif arch=="MAnet":model = smp.MAnet(**kwargs)

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07_tta.ipynb (unless otherwise specified).
__all__ = ['rot90', 'hflip', 'vflip', 'BaseTransform', 'Chain', 'Transformer', 'Compose', 'Merger', 'HorizontalFlip',
'VerticalFlip', 'Rotate90']
__all__ = ['rot90', 'hflip', 'vflip', 'BaseTransform', 'HorizontalFlip', 'VerticalFlip', 'Rotate90', 'Chain',
'Transformer', 'Compose']

@@ -9,3 +9,2 @@ # Cell

import itertools
from functools import partial
from typing import List, Optional, Union

@@ -15,11 +14,14 @@ from fastcore.foundation import store_attr

# Cell
# Fix: removed a stale duplicate `def rot90(x, k=1):` header that preceded the
# scripted definition and made this span syntactically invalid.
@torch.jit.script
def rot90(x: torch.Tensor, k: int = 1):
    "Rotate a batch of images (BCHW) by 90 degrees `k` times in the H/W plane."
    return torch.rot90(x, k, (2, 3))
# Fix: removed a stale duplicate `def hflip(x):` header that preceded the
# scripted definition and made this span syntactically invalid.
@torch.jit.script
def hflip(x: torch.Tensor):
    "Flip a batch of images (BCHW) horizontally (left<->right, i.e. along W)."
    return x.flip(3)
# Fix: removed a stale duplicate `def vflip(x):` header and restored the
# return statement that had been fused into surrounding diff residue.
@torch.jit.script
def vflip(x: torch.Tensor):
    "Flip a batch of images (BCHW) vertically (top<->bottom, i.e. along H)."
    return x.flip(2)

# Cell
class BaseTransform:
class BaseTransform(torch.nn.Module):
identity_param = None
def __init__(self, pname: str, params: Union[list, tuple]): store_attr()
class Chain:
def __init__(self, functions: List[callable]):
self.functions = functions or []
def __call__(self, x):
for f in self.functions:
x = f(x)
return x
class Transformer:
def __init__(self, image_pipeline: Chain, mask_pipeline: Chain):
def __init__(self, pname: str, params: Union[list, tuple]):
super(BaseTransform, self).__init__()
store_attr()
def augment_image(self, image):
return self.image_pipeline(image)
def deaugment_mask(self, mask):
return self.mask_pipeline(mask)
class Compose:
def __init__(self, aug_transforms: List[BaseTransform]):
store_attr()
self.aug_transform_parameters = list(itertools.product(*[t.params for t in self.aug_transforms]))
self.deaug_transforms = aug_transforms[::-1]
self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters]
def __iter__(self) -> Transformer:
for aug_params, deaug_params in zip(self.aug_transform_parameters, self.deaug_transform_parameters):
image_aug_chain = Chain([partial(t.apply_aug_image, **{t.pname: p})
for t, p in zip(self.aug_transforms, aug_params)])
mask_deaug_chain = Chain([partial(t.apply_deaug_mask, **{t.pname: p})
for t, p in zip(self.deaug_transforms, deaug_params)])
yield Transformer(image_pipeline=image_aug_chain, mask_pipeline=mask_deaug_chain)
def __len__(self) -> int:
return len(self.aug_transform_parameters)
# Cell
import matplotlib.pyplot as plt
class Merger:
    """Accumulates per-model/per-TTA predictions and reduces them to one result."""

    def __init__(self):
        self.output = []  # list of tensors, all expected to share one shape

    def append(self, x):
        "Add a prediction (converted to a tensor) to the buffer."
        self.output.append(torch.as_tensor(x))

    def result(self, type='mean'):
        """Reduce the stacked predictions along the accumulation axis.

        :param type: one of 'max'|'mean'|'std'|'uncertainty'|
            'aleatoric_uncertainty'|'epistemic_uncertainty'|'entropy'
        :raises ValueError: for an unknown `type`
        """
        s = torch.stack(self.output)
        if type == 'max':
            result = torch.max(s, dim=0)[0]
        elif type == 'mean':
            result = torch.mean(s, dim=0)
        elif type == 'std':
            result = torch.std(s, dim=0)
        elif type == 'uncertainty':
            # adapted from https://github.com/ykwon0407/UQ_BNN/blob/master/retina/utils.py
            aleatoric_uncertainty = torch.mean(s * (1 - s), dim=0)
            epistemic_uncertainty = torch.mean(s**2, dim=0) - torch.mean(s, dim=0)**2
            result = epistemic_uncertainty + aleatoric_uncertainty
        elif type == 'aleatoric_uncertainty':
            result = torch.mean(s * (1 - s), dim=0)
        elif type == 'epistemic_uncertainty':
            result = torch.mean(s**2, dim=0) - torch.mean(s, dim=0)**2
        elif type == 'entropy':
            result = -torch.sum(s * torch.log(s), dim=0)
        else:
            # Fix: the original formatted `self.type`, which does not exist, so an
            # unknown merge type raised AttributeError instead of the intended ValueError.
            raise ValueError('Not correct merge type `{}`.'.format(type))
        return result
# Cell
class HorizontalFlip(BaseTransform):

@@ -115,12 +43,8 @@ "Flip images horizontally (left->right)"

def __init__(self):
super().__init__("apply", [False, True])
super(HorizontalFlip, self).__init__("apply", [0, 1])
def apply_aug_image(self, image, apply=False, **kwargs):
if apply: image = hflip(image)
return image
def forward(self, x:torch.Tensor, apply:int=0, deaug:bool=False):
if apply==1: x = hflip(x)
return x
def apply_deaug_mask(self, mask, apply=False, **kwargs):
if apply: mask = hflip(mask)
return mask
# Cell

@@ -131,12 +55,9 @@ class VerticalFlip(BaseTransform):

def __init__(self):
super().__init__("apply", [False, True])
super().__init__("apply", [0, 1])
def apply_aug_image(self, image, apply=False, **kwargs):
if apply: image = vflip(image)
return image
def forward(self, x:torch.Tensor, apply:int=0, deaug:bool=False):
if apply==1:
x = vflip(x)
return x
def apply_deaug_mask(self, mask, apply=False, **kwargs):
if apply: mask = vflip(mask)
return mask
# Cell

@@ -147,11 +68,47 @@ class Rotate90(BaseTransform):

def __init__(self, angles: List[int]):
super().__init__("angle", angles)
if self.identity_param not in angles:
angles = [self.identity_param] + list(angles)
super().__init__("angle", angles)
def apply_aug_image(self, image, angle=0, **kwargs):
@torch.jit.export
def apply_aug_image(self, image:torch.Tensor, angle:int=0): #, **kwargs
k = angle // 90 if angle >= 0 else (angle + 360) // 90
#k = torch.div(angle, 90, rounding_mode='trunc') if angle >= 0 else torch.div((angle + 360), 90, rounding_mode='trunc')
return rot90(image, k)
def apply_deaug_mask(self, mask, angle=0, **kwargs):
return self.apply_aug_image(mask, -angle)
def forward(self, x:torch.Tensor, angle:int=0, deaug:bool=False):
return self.apply_aug_image(x, angle=-angle if deaug else angle)
# Cell
class Chain(torch.nn.Module):
    "Applies a fixed sequence of TTA transforms in order; TorchScript-compatible."

    def __init__(self, transforms: List[BaseTransform]):
        super().__init__()
        # ModuleList so the (scriptable) transforms are registered as submodules
        self.transforms = torch.nn.ModuleList(transforms)

    def forward(self, x, args:List[int], deaug:bool=False):
        """Run `x` through every transform.

        `args[i]` is the parameter for the i-th transform (e.g. flip flag or
        rotation angle); `deaug` is forwarded so each transform can invert itself.
        """
        for i, t in enumerate(self.transforms):
            x = t(x, args[i], deaug)
        return x
# Cell
class Transformer(torch.nn.Module):
    "Pairs one fixed TTA parameter combination with its inverse pipeline."

    def __init__(self, transforms: List[BaseTransform], args:List[int]):
        super(Transformer, self).__init__()
        self.aug_pipeline = Chain(transforms)
        # De-augmentation applies the transforms in reverse order
        self.deaug_pipeline = Chain(transforms[::-1])
        self.args = args

    @torch.jit.export
    def augment(self, image:torch.Tensor):
        "Apply the forward (augmentation) pipeline to a batch of images."
        return self.aug_pipeline(image, self.args, deaug=False)

    @torch.jit.export
    def deaugment(self, mask:torch.Tensor):
        "Undo the augmentation on a prediction (reversed args, deaug=True)."
        return self.deaug_pipeline(mask, self.args[::-1], deaug=True)
# Cell
class Compose(torch.nn.Module):
    "Builds one `Transformer` per combination of transform parameters (the TTA grid)."

    def __init__(self, aug_transforms: List[BaseTransform]):
        super(Compose, self).__init__()
        # Cartesian product of every transform's parameter list
        self.transform_parameters = list(itertools.product(*[t.params for t in aug_transforms]))
        # ModuleList so each Transformer is registered (and scriptable) as a submodule
        self.items = torch.nn.ModuleList([Transformer(aug_transforms, args) for args in self.transform_parameters])
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/06_utils.ipynb (unless otherwise specified).
__all__ = ['unzip', 'download_sample_data', 'install_package', 'import_package', 'compose_albumentations',
'ensemble_results', 'plot_results', 'iou', 'dice_score', 'label_mask', 'get_instance_segmentation_metrics',
'export_roi_set', 'calc_iterations', 'get_label_fn', 'save_mask', 'save_unc']
__all__ = ['unzip', 'download_sample_data', 'install_package', 'import_package', 'compose_albumentations', 'clean_show',
'plot_results', 'multiclass_dice_score', 'binary_dice_score', 'dice_score', 'label_mask',
'get_instance_segmentation_metrics', 'export_roi_set', 'calc_iterations', 'get_label_fn', 'save_mask',
'save_unc']

@@ -11,2 +12,3 @@ # Cell

from pathlib import Path
from scipy import ndimage

@@ -19,5 +21,12 @@ from scipy.spatial.distance import jaccard

from scipy.optimize import linear_sum_assignment
from sklearn.metrics import jaccard_score
from sklearn.metrics import jaccard_score, multilabel_confusion_matrix
from sklearn.metrics._classification import _prf_divide
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import albumentations as A
from fastcore.foundation import patch

@@ -85,16 +94,41 @@ from fastcore.meta import delegates

# Cell
def ensemble_results(res_dict, file, std=False):
"Combines single model predictions."
idx = 2 if std else 0
a = [np.array(res_dict[(mod, f)][idx]) for mod, f in res_dict if f==file]
a = np.mean(a, axis=0)
if std:
a = a[...,0]
def clean_show(ax, msk, title, cmap, cbar=None, ticks=None, **kwargs):
    """Show `msk` on `ax` with axes hidden and an optional, centered colorbar.

    Fix: removed two stale lines from a deleted function that had been spliced
    into the final `else` branch and broke it.

    :param cbar: None (hidden), 'experts', 'classes' or 'uncertainty'
    :param ticks: number of experts/classes, used to place colorbar ticks
    :param kwargs: forwarded to `ax.imshow`
    """
    img = ax.imshow(msk, cmap=cmap, **kwargs)
    # Always reserve a slim axis on the right so all panels keep the same width
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    if cbar == 'experts':
        scale = ticks/(ticks+1)
        # Center each tick label within its color band
        cbr = plt.colorbar(img, cax=cax, ticks=[i*(scale)+(scale/2) for i in range(ticks+1)])
        cbr.set_ticklabels([i for i in range(ticks+1)])
        cbr.set_label('# of experts', rotation=270, labelpad=+15, fontsize="larger")
    elif cbar == 'classes':
        scale = ticks/(ticks)
        bounds = [i for i in range(ticks+1)]
        cmap = plt.cm.get_cmap(cmap)
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        cbr = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                           cax=cax, ticks=[i*(scale)+(scale/2) for i in range(ticks)])
        cbr.set_ticklabels([i for i in range(ticks)])
        cbr.set_label('Classes', rotation=270, labelpad=+15, fontsize="larger")
    elif cbar == 'uncertainty':
        cbr = plt.colorbar(img, cax=cax)
    else:
        # No colorbar requested: create and immediately remove it so the
        # reserved axis stays, keeping this panel aligned with the others.
        cbr = plt.colorbar(img, cax=cax)
        cax.set_axis_off()
        cbr.remove()
    ax.set_axis_off()
    ax.set_title(title)
# Cell
def plot_results(*args, df, hastarget=False, model=None, metric_name='dice_score', unc_metric=None, figsize=(20, 20), **kwargs):
def plot_results(*args, df, hastarget=False, num_classes=2, model=None, instance_labels=False,
metric_name='dice_score', unc_metric=None, figsize=(20, 20), msk_cmap='viridis', **kwargs):
"Plot images, (masks), predictions and uncertainties side-by-side."
vkwargs = {'vmin':0, 'vmax':num_classes-1} if not instance_labels else {}
unc_vkwargs = {'vmin':0, 'vmax':1}
class_kwargs = {'cbar':'classes', 'ticks':num_classes} if not instance_labels else {}
if len(args)==4:

@@ -113,35 +147,20 @@ img, msk, pred, pred_std = args

img=img[...,0]
axs[0].imshow(img)
axs[0].set_axis_off()
axs[0].set_title(f'File {df.file}')
clean_show(axs[0], img, f'File {df.file}', None)
unc_title = f'Uncertainty \n {unc_metric}: {df[unc_metric]:.3f}' if unc_metric else 'Uncertainty'
pred_title = 'Prediction' if model is None else f'Prediction {model}'
if len(args)==4:
axs[1].imshow(msk)
axs[1].set_axis_off()
axs[1].set_title('Target')
axs[2].imshow(pred)
axs[2].set_axis_off()
axs[2].set_title(f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}')
axs[3].imshow(pred_std)
axs[3].set_axis_off()
axs[3].set_title(unc_title)
clean_show(axs[1], msk, 'Target', msk_cmap, cbar='classes', ticks=num_classes, **vkwargs, **kwargs)
pred_title = f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}'
clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs)
clean_show(axs[3], pred_std, unc_title, 'hot', cbar='uncertainty', **unc_vkwargs, **kwargs)
elif len(args)==3 and not hastarget:
axs[1].imshow(pred)
axs[1].set_axis_off()
axs[1].set_title(pred_title)
axs[2].imshow(pred_std)
axs[2].set_axis_off()
axs[2].set_title(unc_title)
clean_show(axs[1], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs)
clean_show(axs[2], pred_std, unc_title, 'hot', cbar='uncertainty', **unc_vkwargs, **kwargs)
elif len(args)==3:
axs[1].imshow(msk)
axs[1].set_axis_off()
axs[1].set_title('Target')
axs[2].imshow(pred)
axs[2].set_axis_off()
axs[2].set_title(f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}')
clean_show(axs[1], msk, 'Target', msk_cmap, cbar='classes', ticks=num_classes,**vkwargs, **kwargs)
pred_title = f'{pred_title} \n {metric_name}: {df[metric_name]:.2f}'
clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs)
elif len(args)==2:
axs[1].imshow(pred)
axs[1].set_axis_off()
axs[1].set_title(pred_title)
clean_show(axs[2], pred, pred_title, msk_cmap, **class_kwargs, **vkwargs, **kwargs)
plt.show()

@@ -172,21 +191,44 @@

# Cell
def iou(a,b,threshold=0.5, average='macro', **kwargs):
'''Computes the Intersection-Over-Union metric.'''
a = np.array(a).flatten()
b = np.array(b).flatten()
if a.max()>1 or b.max()>1:
return jaccard_score(a, b, average=average, **kwargs)
else:
a = np.array(a) > threshold
b = np.array(b) > threshold
overlap = a*b # Logical AND
union = a+b # Logical OR
return np.divide(np.count_nonzero(overlap),np.count_nonzero(union))
def multiclass_dice_score(y_true, y_pred, average='macro', **kwargs):
    '''Computes the Sørensen–Dice coefficient for multiclass segmentations.

    Fix: removed a stale fragment of a deleted `dice_score` definition that had
    been spliced into this function's body.

    :param y_true: flat array of true class labels
    :param y_pred: flat array of predicted class labels
    :param average: None (per-class scores), 'micro' or 'macro'
    :param kwargs: forwarded to `multilabel_confusion_matrix` (e.g. `labels`)
    :raises ValueError: for an unknown `average`
    '''
    average_options = (None, 'micro', 'macro')
    if average not in average_options:
        raise ValueError('average has to be one of ' + str(average_options))
    # MCM[c] is the 2x2 confusion matrix of class c: [[TN, FP], [FN, TP]]
    MCM = multilabel_confusion_matrix(y_true, y_pred, **kwargs)
    numerator = 2*MCM[:, 1, 1]
    denominator = 2*MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0]
    if average == 'micro':
        # Pool counts over all classes before dividing
        numerator = np.array([numerator.sum()])
        denominator = np.array([denominator.sum()])
    # _prf_divide returns 0 where the denominator is 0 (absent class)
    dice_scores = _prf_divide(numerator, denominator, '', None, None, '')
    if average is None:
        return dice_scores
    return np.average(dice_scores)
def binary_dice_score(y_true, y_pred):
    '''Compute the Sørensen–Dice coefficient for binary segmentations.

    Accepts boolean (or 0/1 integer) array-likes of equal shape and returns
    2*|A∩B| / (|A| + |B|), which equals the original 2*IoU/(IoU+1) form.
    Fix: two empty masks now score 1.0 instead of NaN from a 0/0 division.
    '''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    intersection = np.count_nonzero(np.logical_and(y_true, y_pred))
    total = np.count_nonzero(y_true) + np.count_nonzero(y_pred)
    if total == 0:
        return 1.0  # both masks empty -> identical segmentations
    return 2 * intersection / total
def dice_score(y_true, y_pred, average='macro', num_classes=2, **kwargs):
    '''Computes the Sørensen–Dice coefficient.

    Dispatches to the binary implementation for two-class 0/1 masks and to the
    multiclass implementation otherwise (labels 0..num_classes-1 are scored).
    '''
    # Work on flattened 1-D copies of the inputs
    flat_true = np.array(y_true).flatten()
    flat_pred = np.array(y_pred).flatten()
    is_multiclass = flat_true.max() > 1 or flat_pred.max() > 1 or num_classes > 2
    if not is_multiclass:
        return binary_dice_score(flat_true, flat_pred)
    return multiclass_dice_score(flat_true, flat_pred, average=average,
                                 labels=list(range(num_classes)), **kwargs)
# Cell

@@ -193,0 +235,0 @@ def label_mask(mask, threshold=0.5, connectivity=4, min_pixel=0, do_watershed=False, exclude_border=False):

Metadata-Version: 2.1
Name: deepflash2
Version: 0.1.8
Version: 0.2.0
Summary: A Deep learning pipeline for segmentation of fluorescent labels in microscopy images

@@ -10,3 +10,2 @@ Home-page: https://github.com/matjesg/deepflash2

Keywords: unet,deep learning,semantic segmentation,microscopy,fluorescent labels
Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha

@@ -16,10 +15,8 @@ Classifier: Intended Audience :: Developers

Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Requires-Python: >=3.6
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
# Welcome to

@@ -180,3 +177,1 @@

The ImagJ-Macro is available [here](https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/ImageJ/Macro_create_maps.ijm).

@@ -1,2 +0,1 @@

# Welcome to

@@ -3,0 +2,0 @@

@@ -11,4 +11,4 @@ [DEFAULT]

branch = master
version = 0.1.8
min_python = 3.6
version = 0.2.0
min_python = 3.7
audience = Developers

@@ -19,5 +19,5 @@ language = English

status = 2
requirements = fastai>=2.1.7 zarr>=2.0 scikit-image imageio ipywidgets openpyxl albumentations>=1.0.0 natsort>=7.1.1 numba>=0.52.0 segmentation-models-pytorch>=0.2 opencv-python-headless>=4.1.1,<4.5.5
#pip_requirements =
#conda_requirements =
requirements = fastai>=2.1.7 zarr>=2.0 scikit-image imageio ipywidgets openpyxl imagecodecs albumentations>=1.0.0 natsort>=7.1.1 numba>=0.52.0 opencv-python-headless>=4.1.1,<4.5.5 timm>=0.5.4
pip_requirements = segmentation-models-pytorch-deepflash2
conda_requirements = segmentation_models_pytorch>=0.3.0
nbs_path = nbs

@@ -24,0 +24,0 @@ doc_path = docs

Sorry, the diff of this file is too big to display