# -*- coding: utf-8 -*-
"""
Plotting tools for 3D and 4D data.
"""
from builtins import ValueError
import os
import pickle
import warnings
import imageio
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm
from .utils import parse_combreg
[docs]def plot_3plane(I, title='', cmap='gray', vmin=None, vmax=None):
"""3 plane plot of 3D image data
Args:
I (array): 3D image array
title (str, optional): Plot title. Defaults to ''.
cmap (str, optional): colormap. Defaults to 'gray'.
vmin (int, optional): Lower window limit. Defaults to None.
vmax (int, optional): Upper window limit. Defaults to None.
"""
[nx, ny, nz] = np.shape(I)
fig = plt.figure(figsize=(12, 6), facecolor='black')
fig.add_subplot(1, 3, 1)
plt.imshow(I[int(nx/2), :, :], cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis('off')
plt.title(' ')
fig.add_subplot(1, 3, 2)
plt.imshow(I[:, int(ny/2), :], cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis('off')
plt.title(title, color='w', size=20)
fig.add_subplot(1, 3, 3)
plt.imshow(I[:, :, int(nz/2)], cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis('off')
plt.title(' ')
# plt.show()
[docs]def timeseries_video(img_ts, interval=100, title=''):
"""Show a time series as a video in a Jupyter Notebook
To view animation in notebook run
```python
from IPython.display import HTML
video = timeseries_video(TS)
HTML(video)
```
Args:
img_ts (array): Time series sliced to desired view [nx,ny,nt]
interval (int, optional): Framerate. Defaults to 100.
title (str, optional): Title. Defaults to ''.
Returns:
video: HTML video object
"""
fig, ax = plt.subplots()
img = ax.imshow(img_ts[:, :, 0], cmap='gray')
plt.axis('off')
plt.title('')
[nx, ny, nt] = np.shape(img_ts)
def init():
img.set_data(img_ts[:, :, 0])
return (img,)
def animate(i):
t = i
img.set_data(img_ts[:, :, t])
ax.set_title('%s (t=%d/%d)' % (title, t, nt-1))
return (img,)
anim = animation.FuncAnimation(fig, animate, init_func=init,
frames=nt, interval=interval, blit=True)
plt.close()
return anim.to_html5_video()
[docs]def imshow3(I, ncol=None, nrow=None, cmap='gray', vmin=None, vmax=None, order='col'):
"""Tiled plot of 3D data
Args:
I (3D array): Data to plot, expanding along last dimentsion
ncol (int, optional): Number of columns. Defaults to None.
nrow (int, optional): Number of rows. Defaults to None.
cmap (str, optional): Matplotlob colormap. Defaults to 'gray'.
vmin (float, optional): Color range lower limit. Defaults to None.
vmax (float, optional): Color range higher limit. Defaults to None.
order (str, optional): Row or column order. Defaults to 'col'.
Returns:
np.array: Tiled array
"""
"""Multi-pane plot of 3D data
Inspired by the matlab function imshow3.
https://github.com/mikgroup/espirit-matlab-examples/blob/master/imshow3.m
Expands the 3D data along the last dimension. Data is shown on the current matplotlib axis.
- ncol: Number of columns
- nrow: Number of rows
- cmap: colormap ('gray')
Output:
- I3: Same image as shown
Args:
I (array): 3D array with 2D images stacked along last dimension
ncol (int, optional): Number of columns. Defaults to None.
nrow (int, optional): Number of rows. Defaults to None.
cmap (str, optional): Colormap. Defaults to 'gray'.
vmin (innt, optional): Lower window limit. Defaults to None.
vmax (int, optional): Upper window limit. Defaults to None.
order (str, optional): Plot order 'col/row'. Defaults to 'col'.
Returns:
array: Image expanded along the third dimension
"""
[nx, ny, n] = np.shape(I)
if (not nrow) and (not ncol):
nrow = int(np.floor(np.sqrt(n)))
ncol = int(n/nrow)
elif not ncol:
ncol = int(np.ceil(n/nrow))
elif not nrow:
nrow = int(np.ceil(n/ncol))
I3 = np.zeros((ny*nrow, nx*ncol))
i = 0
if order == 'col':
for ix in range(ncol):
for iy in range(nrow):
try:
I3[iy*ny:(iy+1)*ny, ix*nx:(ix+1)*nx] = I[:, :, i]
except:
warnings.warn('Warning: Empty slice. Setting to 0 instead')
continue
i += 1
else:
for iy in range(nrow):
for ix in range(ncol):
try:
I3[iy*ny:(iy+1)*ny, ix*nx:(ix+1)*nx] = I[:, :, i]
except:
warnings.warn('Warning: Empty slice. Setting to 0 instead')
continue
i += 1
plt.imshow(I3, cmap=cmap, vmin=vmin, vmax=vmax)
plt.axis('off')
return I3
[docs]def reg_animation(reg_out, images, out_name='animation.gif', slice_pos=None, tnav=None, t0=0, max_d=None, max_r=None, vmax=1, nrot=0):
"""Animation of registration results
Args:
reg_out (str): File name to pickle file
images (np.array): Array of images to display (nx,ny,nz,nt)
out_name (str, optional): Output suffix .gif or .mp4. Defaults to 'animation.gif'.
slice_pos (array, optional): Slice positions. Defaults to center slices.
tnav (float, optional): Navigator duration to display x-axis with time. Defaults to None.
t0 (int, optional): Starting time. Defaults to 0.
max_d (float, optional): Limit of translation plot. Defaults to None.
max_r (float, optional): Limit of translation plot. Defaults to None.
vmax (float, optional): Maximum display range (0-1). Defaults to 1
Raises:
TypeError: If number of navigator images in `images` is not the same as number of registration objects in pickle file.
"""
combreg = pickle.load(open(reg_out, 'rb'))
num_navigators = images.shape[3]
if len(combreg) != num_navigators:
raise ValueError(
f'Length of combreg ({len(combreg)}) is not the same as number of navigator images ({images.shape[3]})')
all_reg = parse_combreg(combreg)
plt.style.use('default')
plt.rcParams.update({'font.size': 14})
# Time axis
if tnav:
plot_xlabel = 'Time [s]'
t = np.arange(num_navigators)*tnav + t0
else:
plot_xlabel = 'Navigator'
t = np.arange(num_navigators)
if not max_d:
max_d = np.ceil(
np.max(np.abs([all_reg['dx'], all_reg['dy'], all_reg['dz']])))
d_axis = [0, max(t), -max_d, max_d]
if not max_r:
max_r = np.ceil(np.rad2deg(
np.max(np.abs([all_reg['rx'], all_reg['ry'], all_reg['rz']]))))
r_axis = [0, max(t), -max_r, max_r]
nx, ny, nz, _ = images.shape
if slice_pos[0] == None:
slice_pos[0] = int(nx/2)
if slice_pos[1] == None:
slice_pos[1] = int(ny/2)
if slice_pos[2] == None:
slice_pos[2] = int(nz/2)
use_raster = True
raster_order = -10
max_img = np.max(abs(images[:]))
def plot_frame(img_idx):
fig = plt.figure(constrained_layout=True, figsize=(12, 8))
canvas = FigureCanvasAgg(fig)
spec = gridspec.GridSpec(ncols=2, nrows=3, figure=fig)
axes = {}
d_ax = fig.add_subplot(spec[0, 0], rasterized=use_raster)
r_ax = fig.add_subplot(spec[0, 1], rasterized=use_raster)
for (i, ax) in enumerate(['x', 'y', 'z']):
d_ax.plot(t, all_reg['d%s' % ax], linewidth=3,
color='C%d' % i, label=ax)
d_ax.plot([t[img_idx], t[img_idx]], [-max_d, max_d], '--k')
plt.gca().set_rasterization_zorder(raster_order)
r_ax.plot(t, np.rad2deg(
all_reg['r%s' % ax]), linewidth=3, color='C%d' % i)
r_ax.plot([t[img_idx], t[img_idx]], [-max_r, max_r], '--k')
plt.gca().set_rasterization_zorder(raster_order)
d_ax.set_title('Translation')
d_ax.set_ylabel(r'$\Delta_%s$ [mm]' % ax)
d_ax.set_xlabel(plot_xlabel)
d_ax.axis(d_axis)
d_ax.grid()
d_ax.legend(loc='upper left')
r_ax.set_title('Rotation')
r_ax.set_ylabel(r'$\alpha_%s$ [deg]' % ax)
r_ax.set_xlabel(plot_xlabel)
r_ax.axis(r_axis)
r_ax.grid()
axes['img'] = fig.add_subplot(spec[1:3, 0:2], rasterized=use_raster)
img_view = np.concatenate([
np.rot90(images[slice_pos[0], :, :, img_idx], nrot),
np.rot90(images[:, slice_pos[1], :, img_idx], nrot),
np.rot90(images[:, :, slice_pos[2], img_idx], nrot)], axis=1)
plt.imshow(img_view/max_img*255, cmap='gray', vmin=0, vmax=vmax*255)
plt.title('Navigator (%d/%d)' % (img_idx+1, num_navigators))
plt.axis('off')
plt.gca().set_rasterization_zorder(raster_order)
return (fig, canvas)
# Determine output type
suffix = os.path.splitext(os.path.basename(out_name))[-1]
out_type = None
if suffix == '.gif':
out_type = 'gif'
elif suffix == '.mp4':
out_type = 'mp4'
else:
raise ValueError("Output name must be .gif or .mp4")
# Produce the frames
frames = []
if out_type == 'mp4':
writer = imageio.get_writer(
out_name, format='FFMPEG', mode='I', fps=10)
print("Processing frames")
for i in tqdm(range(num_navigators)):
fig, canvas = plot_frame(i)
canvas.draw()
buf = canvas.buffer_rgba()
X = np.asarray(buf, dtype=np.uint8)
if out_type == 'mp4':
writer.append_data(X)
else:
frames.append(X)
plt.close(fig)
# Scale data
# uint8_frames = []
# for i in range(num_navigators):
# uint8_frames.append(np.array(frames[i], dtype=np.uint8))
print("Saving output to: {}".format(out_name))
if out_type == 'mp4':
writer.close()
else:
imageio.mimsave(out_name, frames)
[docs]def report_plot(combreg, maxd, maxr, navtr=None, bw=False):
"""Plot registration results
Args:
combreg (str): Filename of registration results
maxd (float): Max displacement for y-limit
maxr (float): Max rotation for y-limit
navtr (float, optional): Duration of navigator for time axis. Defaults to None.
bw (bool, optional): Plot in black and white. Defaults to False.
"""
# Summarise statistics
all_reg = parse_combreg(combreg)
if bw:
plt.style.use('grayscale')
else:
plt.style.use('default')
fig = plt.figure(figsize=(12, 4), facecolor='w')
plt.rcParams.update({'font.size': 16})
# Axis limits
max_d = float(maxd)
max_r = float(maxr)
if not max_d:
max_d = np.ceil(
np.max(np.abs([all_reg['dx'], all_reg['dy'], all_reg['dz']])))
if not max_r:
max_r = np.ceil(np.rad2deg(
np.max(np.abs([all_reg['rx'], all_reg['ry'], all_reg['rz']]))))
x = list(range(len(combreg)))
if navtr:
x *= navtr
d_ax = fig.add_subplot(1, 2, 1)
rot_ax = fig.add_subplot(1, 2, 2)
for (i, ax) in enumerate(['x', 'y', 'z']):
d_ax.plot(x, all_reg['d%s' % ax],
linewidth=3, label=ax,)
rot_ax.plot(x, np.rad2deg(all_reg['r%s' % ax]),
linewidth=3, label=ax)
d_ax.axis([0, max(x), -max_d, max_d])
d_ax.set_ylabel(r'$\Delta$ [mm]')
d_ax.grid()
d_ax.set_title('Translation')
d_ax.legend()
rot_ax.axis([0, max(x), -max_r, max_r])
rot_ax.grid()
rot_ax.set_ylabel(r'$\alpha$ [deg]')
rot_ax.set_title('Rotation')
rot_ax.legend()
if navtr:
rot_ax.set_xlabel('Time [s]')
d_ax.set_xlabel('Time [s]')
else:
rot_ax.set_xlabel('Navigator')
d_ax.set_xlabel('Navigator')
plt.tight_layout()