CARE restoration (prediction)#

In this notebook we use the trained CARE network to restore a set of images.

from skimage.io import imread, imsave
from csbdeep.models import CARE
from csbdeep.utils import normalize
import sys
sys.path.append('../')
import decon_helper as dh

Load the input images#

In this cell we load the input images. Note that they are larger than the patch size we trained on; the CSBDeep framework will take care of applying the network in tiles.

input_names = [r'../../data/deep learning testing/inputs/nuclei.tif',
              r'../../data/deep learning testing/inputs/spheres2.tif',
              r'../../data/deep learning testing/inputs/spheres3.tif']
inputs = []
minmaxes = []

# percentiles for normalization: 0 and 100 give plain min-max scaling
pmin = 0
pmax = 100

for input_name in input_names:
    img = imread(input_name)
    nmin = img.min()
    nmax = img.max()

    # record the original min and max so the network output can be denormalized later
    minmaxes.append({'min': nmin, 'max': nmax})
    img = normalize(img, pmin, pmax)

    #img = (img.astype('float32')-nmin) / (nmax - nmin)
    inputs.append(img)
    dh.show_xyz_slice(img, 'tnia')
[Figure: XYZ slice views of the three normalized input images]
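With pmin=0 and pmax=100, `normalize` reduces to plain min-max scaling, which is the same operation the commented-out line above computes by hand. A quick check on a toy array:

import numpy as np

x = np.array([0., 50., 100.])
print(normalize(x, 0, 100))  # ~[0., 0.5, 1.]  (min-max scaling)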

Load the model#

Load the model that we trained in the previous notebook.

# load the trained model by name; passing None for the config tells
# CSBDeep to read the configuration and weights from basedir/model_name

#model_name = 'cytopacq'
model_name = 'spheres'
#model_name = 'combined'

model = CARE(None, model_name, basedir='../../models')
Loading network weights from 'weights_best.h5'.
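As a quick sanity check, the configuration that was read from the model folder can be printed; CSBDeep stores it on the model's config attribute:

print(model.config)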

Apply the model#

restorations = []

# normalizer=None because the inputs were already normalized above;
# n_tiles splits each volume into tiles so prediction fits in memory
for img in inputs:
    restored = model.predict(img, axes='ZYX', n_tiles=(1, 2, 2), normalizer=None)
    restorations.append(restored)
1/1 [==============================] - 0s 325ms/step
1/1 [==============================] - 0s 243ms/step
 25%|██▌       | 1/4 [00:00<?, ?it/s]
1/1 [==============================] - 0s 42ms/step
 50%|█████     | 2/4 [00:00<00:00,  5.99it/s]
1/1 [==============================] - 0s 42ms/step
 75%|███████▌  | 3/4 [00:00<00:00,  4.16it/s]
1/1 [==============================] - 0s 41ms/step
100%|██████████| 4/4 [00:01<00:00,  3.91it/s]
1/1 [==============================] - 0s 44ms/step
 25%|██▌       | 1/4 [00:00<?, ?it/s]
1/1 [==============================] - 0s 50ms/step
 50%|█████     | 2/4 [00:00<00:00,  5.88it/s]
1/1 [==============================] - 0s 42ms/step
 75%|███████▌  | 3/4 [00:00<00:00,  4.16it/s]
1/1 [==============================] - 0s 40ms/step
100%|██████████| 4/4 [00:01<00:00,  3.90it/s]
for restored in restorations:
    dh.show_xyz_slice(restored, 'tnia')
[Figure: XYZ slice views of the three restored images]
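The restorations can also be written back to disk with the `imsave` imported above. A minimal sketch, assuming a 'restorations' output folder next to the inputs (the path is an assumption; adjust it to your layout):

import os

# assumed output folder - adjust to your own directory layout
out_dir = r'../../data/deep learning testing/restorations'
os.makedirs(out_dir, exist_ok=True)

for input_name, restored in zip(input_names, restorations):
    out_name = os.path.join(out_dir, os.path.basename(input_name))
    imsave(out_name, restored.astype('float32'))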
import napari
import numpy as np

# compare the third image (spheres3) against a classical deconvolution
n = 2
decon = imread(r'../../data/deep learning testing/deconvolutions/spheres3.tif')

viewer = napari.Viewer()
viewer.add_image(inputs[n], name='input')
# the network can produce small negative values, so clip them for display
viewer.add_image(np.clip(restorations[n], 0, restorations[n].max()), name='restored')
viewer.add_image(decon, name='deconvolution')
<Image layer 'deconvolution' at 0x1ffd94bce50>
# the input is min-max normalized to [0, 1]; the network output can undershoot and overshoot that range
print(inputs[n].min(), inputs[n].max(), inputs[n].mean())
print(restorations[n].min(), restorations[n].max(), restorations[n].mean())
0.0 1.0 0.012505261
-0.22564787 1.2641255 0.005944527

Plot line profiles to see if relative intensities were preserved#

This is tricky because the network input was normalized, so to compare intensities on the original scale we have to 'denormalize' the output from the neural network.
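Since we recorded each image's original min and max when loading, the min-max normalization can be inverted directly. A minimal sketch (restored_denorm is an illustrative name, not used below):

nmin = minmaxes[n]['min']
nmax = minmaxes[n]['max']

# invert the min-max normalization: map the network output back to [nmin, nmax]
restored_denorm = restorations[n] * (nmax - nmin) + nmin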

import matplotlib.pyplot as plt

nmax = minmaxes[n]['max']
nmin = minmaxes[n]['min']

phantom = imread('../../data/deep learning testing/ground truth/spheres4.tif')
im = inputs[n]  # *(nmax-nmin)+nmin would denormalize to the original range
care = restorations[n]  # *(nmax-nmin) (or *inputs[n].sum()/restorations[n].sum()) to denormalize

s = im.shape

fig, ax = plt.subplots(figsize=(18, 14))

# profile along x, halfway through z and 7/8 of the way down y
line_z = s[0] // 2
line_y = 7 * s[1] // 8

line = im[line_z, line_y, :]
ax.plot(line, label='image')

line = care[line_z, line_y, :]
ax.plot(line, label='care')

#line = phantom[line_z, line_y, :]
#ax.plot(line, label='truth')

ax.legend()
<matplotlib.legend.Legend at 0x2aefb3a3160>
[Figure: line profiles through the input image and the CARE restoration]