Stardist compute receptive field#

When trying to train a stardist model it is useful to check whether the configured neural network has a large enough field of view to detect the objects in the images. The stardist 3D training example notebook from the stardist GitHub repo is a great starting point for creating customized stardist workflows. It contains the code snippet below to calculate the network field of view.

median_size = calculate_extents(Y, np.median)
fov = np.array(model._axes_tile_overlap('ZYX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

A small issue: it may make more sense to calculate the max extent rather than the median. Another issue is that the reported field of view appears to max out at 32 in all dimensions; try creating a network with a larger field of view and you still get 32.
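
As a first tweak we can simply swap np.max into the example snippet (Y and model are the training labels and the model from the example notebook):

max_size = calculate_extents(Y, np.max)
fov = np.array(model._axes_tile_overlap('ZYX'))
print(f"max object size:         {max_size}")
print(f"network field of view :  {fov}")
if any(max_size > fov):
    print("WARNING: max object size larger than field of view of the neural network.")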

This notebook looks into the stardist ‘field of view’ calculation in a little more detail, with the goal of computing larger 3D field of view values.

First our imports#

from stardist.models import Config2D, StarDist2D
import raster_geometry as rg
from stardist import Rays_GoldenSpiral, Rays_Octo
from tnia.plotting.plt_helper import imshow2d
from stardist import calculate_extents
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

Create a 2D ‘object’#

This is just a circle. Make the circle bigger if you want to design a network for bigger objects. In a real project we’d have labels with many objects in them. Here we just use this very simple ‘label’ as an example of the biggest object we want the system to handle.

radius = 75
xy_spacing = 1

image_size = 256
xy_pixels_radius = int(radius / xy_spacing)

print(f'xy_pixels_radius {xy_pixels_radius}')

ellipsoid2D = rg.ellipse([image_size, image_size], [xy_pixels_radius, xy_pixels_radius]).astype('uint8')*255

fig = imshow2d(ellipsoid2D, 3,3)
fig.suptitle('ellipsoid2D')
xy_pixels_radius 75
Text(0.5, 0.98, 'ellipsoid2D')
[Figure: ellipsoid2D label image]

Calculate the extent of our circle#

extents2D = calculate_extents(ellipsoid2D)

print('extent 2D', extents2D)
extent 2D [150. 150.]

Design a stardist 2D network#

We make the model_name ‘delete me 2D’ because we aren’t going to train it or anything, so we can delete it after we are done. Then we need to define the grid, unet_kernel_size, and depth. These parameters determine the receptive field.

  1. A larger ‘grid’ means the filters take a larger stride at the first level, and thus ‘see more’.

  2. A larger kernel size means the network ‘sees more’ at each level.

  3. At each level the images are downsampled, which means the kernels applied at that level process information from a larger spatial extent. (A rough estimate of how these factors combine is sketched below.)
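
Before measuring anything empirically it can help to have a ballpark number. The helper below is not part of stardist; it is a back-of-the-envelope sketch using standard receptive-field arithmetic, assuming two convolutions per resolution level on the down and up paths, 2x max pooling, and treating the grid as an extra initial stride. The empirically measured field of view is usually somewhat smaller, so treat this only as a rough upper bound.

def approx_unet_receptive_field(depth, kernel_size, grid=1, n_conv_per_depth=2, pool=2):
    # r = receptive field size (in input pixels), j = stride of the current feature map relative to the input
    r, j = 1, grid
    for _ in range(depth):                              # down path
        r += n_conv_per_depth * (kernel_size - 1) * j   # convolutions
        r += (pool - 1) * j                             # max pooling
        j *= pool
    r += n_conv_per_depth * (kernel_size - 1) * j       # bottleneck convolutions
    for _ in range(depth):                              # up path
        j //= pool
        r += n_conv_per_depth * (kernel_size - 1) * j   # convolutions after upsampling
    return r

# roughly 200 pixels for the settings used below (depth=4, 3x3 kernels, grid=1)
print(approx_unet_receptive_field(depth=4, kernel_size=3, grid=1))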

model_name = 'delete me 2D'
grid = (1, 1)
unet_kernel_size = (3, 3)
unet_n_filter_base = 4   # defined but not passed below, so the Config2D default is used
depth = 4
n_channel_in = 1
n_rays = 16
axes = 'YX'

config2D = Config2D(n_rays=n_rays, axes=axes, n_channel_in=n_channel_in, grid=grid, train_patch_size=(256, 256), unet_n_depth=depth, unet_kernel_size=unet_kernel_size)
model2D = StarDist2D(config=config2D, name=model_name, basedir='models')
Using default values: prob_thresh=0.5, nms_thresh=0.4.

Apply the example ‘field of view’ code#

Note we make a couple of changes: first, we use the max size instead of the median; second, we also compute the receptive field directly.

# for a 3D model (like the one in the stardist 3D example notebook) you can pass a 3D img_size:
rf = model._compute_receptive_field(img_size=(256, 512, 512))

Note that in the examples a function called _axes_tile_overlap is used to calculate the receptive field. However, if you inspect the code, _axes_tile_overlap wraps _compute_receptive_field with a fixed image size. Calling _compute_receptive_field directly allows us to choose the img_size.

That is really important. Stardist calculates the receptive field by applying the real network and empirically measuring how much of the result contains ‘information’. If you use a small image size like 64, 64, 64 the entire image will contain information. To measure the receptive field properly you need to make your img_size larger than the receptive field you expect.
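
To make that concrete, here is a quick check (a sketch, not a recorded run): measure the receptive field of the 2D model defined above with a small and a large probe image. On the small probe the measurement is clipped by the image itself.

rf_small = model2D._compute_receptive_field(img_size=(64, 64))
rf_large = model2D._compute_receptive_field(img_size=(512, 512))
print('receptive field measured on a 64x64 probe  :', rf_small)
print('receptive field measured on a 512x512 probe:', rf_large)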

max_size = calculate_extents(ellipsoid2D, np.max)
fov = np.array(model2D._axes_tile_overlap('YX'))
rf = model2D._compute_receptive_field(img_size=(512, 512))
print(f"max object radius:      {max_size/2}")
print(f"network field of view :  {fov}")
print(f"receptive field :  {rf}")
#rf_total = np.array([before + after for before, after in rf])
#if any(max_size > rf_total):
#    print("WARNING: max object size larger than receptive field of the neural network.")
max object radius:      [75. 75.]
network field of view :  [64 64]
receptive field :  [(93, 89), (90, 89)]

Look at the receptive field function step by step#

  1. It generates two images: an empty image and an impulse (a single non-zero pixel at the center). This is really cool. An impulse is how engineers measure the point spread function, or impulse response, of a system. Note that for a neural network the response is non-linear; however, we aren’t concerned about the exact shape (this network hasn’t even been trained), only about ‘how far’ information propagates.

img_size = 256, 256
mid = tuple(s//2 for s in img_size)
x = np.zeros((1,)+img_size+(1,), dtype=np.float32)  # impulse image (with batch and channel dimensions)
z = np.zeros_like(x)                                # empty image
x[(0,)+mid+(slice(None),)] = 1                      # single non-zero pixel at the center

Make predictions with our inputs#

Our inputs are an impulse and an empty image; the outputs are the ‘impulse response’ and the offset (the constant response to an empty input).

y  = model2D.keras_model.predict(x, verbose=0)[0][0,...,0]  # response to the impulse
y0 = model2D.keras_model.predict(z, verbose=0)[0][0,...,0]  # response to the empty image (the offset)
dif = np.abs(y-y0)

Calculate how far information ‘flowed’#

from scipy.ndimage import zoom
print(grid)

# upsample the (grid-subsampled) outputs back to the input resolution
yz  = zoom(y,  grid, order=0)
y0z = zoom(y0, grid, order=0)

# indices where the impulse changed the output relative to the empty input
ind = np.where(np.abs(yz - y0z) > 0)

mid = tuple(s//2 for s in img_size)

# per axis: first and last affected index, then distances from the impulse position
for i, m in zip(ind, mid):
    print(np.min(i), np.max(i), m - np.min(i), np.max(i) - m)
(1, 1)
35 217 93 89
38 217 90 89
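
For convenience we can pack these per-axis distances into the same (before, after) format that _compute_receptive_field returns (rf_manual is just an illustrative name), which makes it easy to compare with the value printed earlier:

rf_manual = [(m - np.min(i), np.max(i) - m) for i, m in zip(ind, mid)]
print('manually measured receptive field:', rf_manual)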

Plot diff image#

We look at the ‘dif’ image as it shows us which parts of the image contain information that is different from the constant value resulting from an empty input.

fig = imshow2d(dif, 3, 3, vmax=dif.max()/100)
receptive_field = (dif > 0)*255
fig = imshow2d(receptive_field, 3, 3)
[Figures: dif image (scaled) and thresholded receptive field]

View in Napari#

import napari
viewer = napari.Viewer()
#viewer.add_image(x.squeeze(), name='x')
viewer.add_image(receptive_field, name='receptive field')
viewer.add_labels(ellipsoid2D, name='ellipsoid2D')
<Labels layer 'ellipsoid2D' at 0x1611773c970>
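
Clean up#

Since the ‘delete me 2D’ model was only created to probe the receptive field, its folder (created under ‘models’ above) can be removed when we are done. A minimal sketch using shutil, assuming the same model_name and basedir as above:

import shutil
from pathlib import Path

# remove the throw-away model directory created by StarDist2D(..., name=model_name, basedir='models')
model_dir = Path('models') / model_name
if model_dir.exists():
    shutil.rmtree(model_dir)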