CARE restoration (prediction)#
In this notebook we use the trained CARE network to restore images.
from skimage.io import imread, imsave
from tnia.plotting.projections import show_xyz_slice, show_xyz_max
from csbdeep.models import Config, CARE
from csbdeep.utils import normalize
from skimage.transform import resize
import sys
sys.path.append('../')  # make the local decon_helper module importable
import decon_helper as dh
import os
Load the input images#
In this cell we load the input images, record each image's min and max (so the output can be 'denormalized' later), and normalize to the [0, 1] range. Note that the images are bigger than the patch size we trained on; the CSBDeep framework will take care of applying the network in chunks.
input_names = [r'../../data/deep learning testing/inputs/nuclei.tif',
               r'../../data/deep learning testing/inputs/spheres2.tif',
               r'../../data/deep learning testing/inputs/spheres3.tif']

inputs = []
minmaxes = []

# percentiles for normalization; 0 and 100 make this plain min-max scaling
pmin = 0
pmax = 100

for input_name in input_names:
    input = imread(input_name)

    # remember the original min and max so we can 'denormalize' later
    nmin = input.min()
    nmax = input.max()
    minmaxes.append({'min': nmin, 'max': nmax})

    # equivalent manual scaling: (input.astype('float32') - nmin) / (nmax - nmin)
    input = normalize(input, pmin, pmax)
    inputs.append(input)
    dh.show_xyz_slice(input, 'tnia')
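With pmin=0 and pmax=100, csbdeep's percentile-based normalize reduces to plain min-max scaling; a minimal sketch of the equivalent computation (the random test volume is only for illustration):

import numpy as np
x = np.random.poisson(10, size=(16, 64, 64)).astype('float32')
lo, hi = np.percentile(x, 0), np.percentile(x, 100)  # same as x.min(), x.max()
assert np.allclose(normalize(x, 0, 100), (x - lo) / (hi - lo), atol=1e-4)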
Load the model#
Load the model that we trained in the previous notebook.
# load the model by name from the models directory; other trained models
# can be selected by uncommenting the corresponding name
#model_name = 'cytopacq'
model_name = 'spheres'
#model_name = 'combined'
model = CARE(None, model_name, basedir='../../models')
Loading network weights from 'weights_best.h5'.
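If loading fails it is usually because basedir/model_name does not point at a trained model folder. A quick sanity check (a minimal sketch; the exact file list depends on how the model was saved, but config.json and weights_best.h5 should be present):

model_dir = os.path.join('../../models', model_name)
print(os.listdir(model_dir))  # expect e.g. 'config.json' and 'weights_best.h5'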
Apply the model#
restorations = []

for input in inputs:
    # the inputs were normalized above, so pass normalizer=None;
    # n_tiles=(1, 2, 2) splits each volume into 4 tiles to limit memory use
    restored = model.predict(input, axes='ZYX', n_tiles=(1, 2, 2), normalizer=None)
    restorations.append(restored)
1/1 [==============================] - 0s 325ms/step
100%|██████████| 4/4 [00:01<00:00, 3.91it/s]
100%|██████████| 4/4 [00:01<00:00, 3.90it/s]
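If prediction runs out of GPU memory on larger volumes, a finer tiling usually helps; a minimal sketch (the tile counts are an assumption, tune them for your hardware and volume size):

# split each volume into more, smaller tiles; CSBDeep stitches the results back together
restored = model.predict(inputs[0], axes='ZYX', n_tiles=(1, 4, 4), normalizer=None)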
for restored in restorations:
    dh.show_xyz_slice(restored, 'tnia')
import napari
import numpy as np

# compare the third input (spheres3) against its CARE restoration and a
# classical deconvolution of the same volume
n = 2
decon = imread(r'../../data/deep learning testing/deconvolutions/spheres3.tif')

viewer = napari.Viewer()
viewer.add_image(inputs[n], name='input')
# the network output can contain small negative values, so clip at zero
viewer.add_image(np.clip(restorations[n], 0, restorations[n].max()), name='restored')
viewer.add_image(decon, name='deconvolution')
<Image layer 'deconvolution' at 0x1ffd94bce50>
# compare intensity statistics before and after restoration
print(inputs[n].min(), inputs[n].max(), inputs[n].mean())
print(restorations[n].min(), restorations[n].max(), restorations[n].mean())
0.0 1.0 0.012505261
-0.22564787 1.2641255 0.005944527
Plot line profiles to see if relative intensities were preserved#
This is tricky because we have to ‘denormalize’ the output of the neural network back to the original intensity range.
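As a minimal sketch, the denormalization simply inverts the min-max scaling using the values we recorded before normalizing (this mirrors the commented-out scaling in the cell below):

nmin, nmax = minmaxes[n]['min'], minmaxes[n]['max']
restored_denorm = restorations[n] * (nmax - nmin) + nmin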
import matplotlib.pyplot as plt

nmin = minmaxes[n]['min']
nmax = minmaxes[n]['max']

# ground truth phantom (only plotted if it matches the input geometry)
phantom = imread('../../data/deep learning testing/ground truth/spheres4.tif')

im = inputs[n]          # optionally denormalize: im*(nmax-nmin)+nmin
care = restorations[n]  # optionally denormalize: care*(nmax-nmin)+nmin

s = im.shape
fig, ax = plt.subplots(figsize=(18, 14))

# profile along x, halfway through z and 7/8 of the way down y
line_z = int(s[0] // 2)
line_y = int(7 * s[1] // 8)

ax.plot(im[line_z, line_y, :], label='image')
ax.plot(care[line_z, line_y, :], label='care')
#ax.plot(phantom[line_z, line_y, :], label='truth')
ax.legend()
<matplotlib.legend.Legend at 0x2aefb3a3160>
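imsave was imported at the top but never used; a minimal sketch of writing the restoration to disk (the output path is an assumption, adjust it for your layout):

imsave('../../data/deep learning testing/restorations/spheres3_care.tif',
       restorations[n].astype('float32'))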