Source code for pymerlin.reg

# -*- coding: utf-8 -*-
"""
Containes the core registration components for MERLIN. The framework builds on ITK and is heavily inspired by ANTs.
"""

import logging
import os
import pickle

import itk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .dataIO import create_image, read_image_h5
from .utils import versor_to_euler


[docs]def otsu_filter(image): """Applies an Otsu filter Args: image (itk.Image): Input image Returns: itk.OtsuThresholdImageFilter: Filter """ print("Performing Otsu thresholding") OtsuOutImageType = itk.Image[itk.UC, 3] filt = itk.OtsuThresholdImageFilter[type(image), OtsuOutImageType].New() filt.SetInput(image) filt.Update() return filt
[docs]def versor_reg_summary(registrations, reg_outs, names=None, doprint=True, show_legend=True): """Summarise results from one or more versor registration experiments Args: registrations (list): List of registration objects reg_outs (list): List of dictionaries of registration outputs names (list, optional): Labels for each registration. Defaults to None. doprint (bool, optional): Print output. Defaults to True. show_legend (bool, optional): Show plot legend. Defaults to True. Returns: pandas.DataFrame: Summary of registrations """ df_dict = {} index = ['Trans X', 'Trans Y', 'Trans Z', 'Versor X', 'Versor Y', 'Versor Z', 'Iterations', 'Metric Value'] if doprint: fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(12, 8)) if not names: names = ['Int %d' % x for x in range(len(registrations))] for (reg, reg_out, name) in zip(registrations, reg_outs, names): # Examine the result transform = reg.GetTransform() optimizer = reg.GetOptimizer() final_parameters = transform.GetParameters() versorX = final_parameters[0] versorY = final_parameters[1] versorZ = final_parameters[2] transX = final_parameters[3] transY = final_parameters[4] transZ = final_parameters[5] nits = optimizer.GetCurrentIteration() best_val = optimizer.GetValue() # Summarise data and store in dictionary reg_data = [transX, transY, transZ, versorX, versorY, versorZ, nits, best_val] df_dict[name] = reg_data # Creat plots if doprint: ax = axes[0, 0] ax.plot(reg_out['cv'], '-o') ax.set_ylabel('') ax.set_title('Optimizer Value') ax.grid('on') ax = axes[0, 1] ax.plot(reg_out['lrr'], '-o') ax.set_ylabel('') ax.set_title('Learning Rate Relaxation') ax.grid('on') ax = axes[0, 2] ax.plot(reg_out['sl'], '-o', label=name) ax.set_ylabel('') ax.set_title('Step Length') ax.grid('on') if show_legend: ax.legend() ax = axes[1, 0] ax.plot(reg_out['tX'], '-o') ax.set_ylabel('[mm]') ax.set_title('Translation X') ax.grid('on') ax = axes[1, 1] ax.plot(reg_out['tY'], '-o') ax.set_ylabel('[mm]') ax.set_title('Translation Y') ax.grid('on') ax = axes[1, 2] ax.plot(reg_out['tZ'], '-o') ax.set_ylabel('[mm]') ax.set_title('Translation Z') ax.grid('on') ax = axes[2, 0] ax.plot(reg_out['vX'], '-o') ax.set_xlabel('Itteration') ax.set_ylabel('') ax.set_title('Versor X') ax.grid('on') ax = axes[2, 1] ax.plot(reg_out['vY'], '-o') ax.set_xlabel('Itteration') ax.set_ylabel('') ax.set_title('Versor Y') ax.grid('on') ax = axes[2, 2] ax.plot(reg_out['vZ'], '-o') ax.set_xlabel('Itteration') ax.set_ylabel('') ax.set_title('Versor Z') ax.grid('on') if doprint: plt.tight_layout() plt.show() # Create Dataframe with output data df = pd.DataFrame(df_dict, index=index) # Determine if running in notebook or shell to get right print function env = os.environ program = os.path.basename(env['_']) if doprint: print(df) return df
[docs]def versor_watcher(reg_out, optimizer): """Logging for registration Args: reg_out (dict): Structure for logging registration optimizer (itk.RegularStepGradientDescentOptimizerv4): Optimizer object Returns: function: Logging function """ logging.debug("{:s} \t {:6s} \t {:6s} \t {:6s} \t {:6s} \t {:6s} \t {:6s} \t {:6s}".format( 'Itt', 'Value', 'vX', 'vY', 'vZ', 'tX', 'tY', 'tZ')) def opt_watcher(): cv = optimizer.GetValue() cpos = np.array(optimizer.GetCurrentPosition()) cit = optimizer.GetCurrentIteration() lrr = optimizer.GetCurrentLearningRateRelaxation() sl = optimizer.GetCurrentStepLength() # Store logged values reg_out['cv'].append(cv) reg_out['vX'].append(cpos[0]) reg_out['vY'].append(cpos[1]) reg_out['vZ'].append(cpos[2]) reg_out['tX'].append(cpos[3]) reg_out['tY'].append(cpos[4]) reg_out['tZ'].append(cpos[5]) reg_out['sl'].append(sl) reg_out['lrr'].append(lrr) logging.debug("{:d} \t {:6.5f} \t {:6.3f} \t {:6.3f} \t {:6.3f} \t {:6.3f} \t {:6.3f} \t {:6.3f}".format( cit, cv, cpos[0], cpos[1], cpos[2], cpos[3], cpos[4], cpos[5])) return opt_watcher
[docs]def winsorize_image(image, p_low, p_high): """Applies winsorize filter to image Args: image (itk.Image): Input image p_low (float): Lower percentile p_high (float): Upper percentile Returns: itk.ThresholdImageFilter: Threshold filter """ Dimension = 3 PixelType = itk.template(image)[1][0] ImageType = itk.Image[PixelType, Dimension] # Histogram nbins = 1000 # Allows 0.001 precision hist_filt = itk.ImageToHistogramFilter[ImageType].New() hist_filt.SetInput(image) hist_filt.SetAutoMinimumMaximum(True) hist_filt.SetHistogramSize([nbins]) hist_filt.Update() hist = hist_filt.GetOutput() low_lim = hist.Quantile(0, p_low) high_lim = hist.Quantile(0, p_high) filt = itk.ThresholdImageFilter[ImageType].New() filt.SetInput(image) filt.ThresholdBelow(low_lim) filt.ThresholdAbove(high_lim) filt.ThresholdOutside(low_lim, high_lim) return filt
[docs]def threshold_image(image, low_lim): """Threshold image at given value Args: image (itk.Image): Input image low_lim (float): Lower threshold Returns: itk.Image: Thresholded image """ Dimension = 3 PixelType = itk.template(image)[1][0] ImageType = itk.Image[PixelType, Dimension] thresh_filt = itk.ThresholdImageFilter[ImageType].New() thresh_filt.ThresholdBelow(float(low_lim)) thresh_filt.SetOutsideValue(0) thresh_filt.SetInput(image) thresh_filt.Update() return thresh_filt.GetOutput()
[docs]def resample_image(registration, moving_image, fixed_image): """Resample image with registration parameters Args: registration (itk.ImageRegistrationMethodv4): Registration object moving_image (itk.Image): Moving image fixed_image (itk.Image): Fixed image Returns: itk.ResampleImageFilter: Resampler filter """ logging.info("Resampling moving image") transform = registration.GetTransform() final_parameters = transform.GetParameters() TransformType = itk.VersorRigid3DTransform[itk.D] finalTransform = TransformType.New() finalTransform.SetFixedParameters( registration.GetOutput().Get().GetFixedParameters()) finalTransform.SetParameters(final_parameters) ResampleFilterType = itk.ResampleImageFilter[type(moving_image), type(moving_image)] resampler = ResampleFilterType.New() resampler.SetTransform(finalTransform) resampler.SetInput(moving_image) resampler.SetSize(fixed_image.GetLargestPossibleRegion().GetSize()) resampler.SetOutputOrigin(fixed_image.GetOrigin()) resampler.SetOutputSpacing(fixed_image.GetSpacing()) resampler.SetOutputDirection(fixed_image.GetDirection()) resampler.SetDefaultPixelValue(0) resampler.Update() return resampler
[docs]def get_versor_factors(registration): """Calculate correction factors from Versor object Args: registration (itk.ImageRegistrationMethodv4): Registration object Returns: dict: Correction factors """ transform = registration.GetTransform() final_parameters = transform.GetParameters() TransformType = itk.VersorRigid3DTransform[itk.D] finalTransform = TransformType.New() finalTransform.SetFixedParameters( registration.GetOutput().Get().GetFixedParameters()) finalTransform.SetParameters(final_parameters) matrix = itk.array_from_matrix(finalTransform.GetMatrix()) offset = np.array(finalTransform.GetOffset()) regParameters = registration.GetOutput().Get().GetParameters() corrections = {'R': matrix, 'vx': regParameters[0], 'vy': regParameters[1], 'vz': regParameters[2], 'dx': regParameters[3], 'dy': regParameters[4], 'dz': regParameters[5] } return corrections
[docs]def setup_optimizer(PixelType, opt_range, relax_factor, nit=250, learning_rate=0.1, convergence_window_size=10, convergence_value=1E-6, min_step_length=1E-4): """Setup optimizer object Args: PixelType (itkCType): ITK pixel type opt_range (list): Range for optimizer relax_factor (float): Relaxation factor nit (int, optional): Number of iterations. Defaults to 250. learning_rate (float, optional): Optimizer learning rate. Defaults to 0.1. convergence_window_size (int, optional): Number of points to use in evaluating convergence. Defaults to 10. convergence_value ([type], optional): Value at which convergence is reached. Defaults to 1E-6. Returns: itk.RegularStepGradientDescentOptimizerv4: Optimizer object """ logging.info("Initialising Regular Step Gradient Descent Optimizer") optimizer = itk.RegularStepGradientDescentOptimizerv4[PixelType].New() OptimizerScalesType = itk.OptimizerParameters[PixelType] # optimizerScales = OptimizerScalesType( # initialTransform.GetNumberOfParameters()) optimizerScales = OptimizerScalesType(6) # Set scales <- Not sure about this part rotationScale = 1.0/np.deg2rad(opt_range[0]) translationScale = 1.0/opt_range[1] optimizerScales[0] = rotationScale optimizerScales[1] = rotationScale optimizerScales[2] = rotationScale optimizerScales[3] = translationScale optimizerScales[4] = translationScale optimizerScales[5] = translationScale optimizer.SetScales(optimizerScales) logging.info("Setting up optimizer") logging.info("Rot/Trans scales: {}/{}".format(opt_range[0], opt_range[1])) logging.info("Number of itterations: %d" % nit) logging.info("Learning rate: %.2f" % learning_rate) logging.info("Relaxation factor: %.2f" % relax_factor) logging.info("Convergence window size: %d" % convergence_window_size) logging.info("Convergence value: %f" % convergence_value) optimizer.SetNumberOfIterations(nit) optimizer.SetLearningRate(learning_rate) # Default in ANTs optimizer.SetRelaxationFactor(relax_factor) optimizer.SetConvergenceWindowSize(convergence_window_size) optimizer.SetMinimumConvergenceValue(convergence_value) optimizer.SetMinimumStepLength(min_step_length) return optimizer
[docs]def versor3D_registration(fixed_image_fname, moving_image_fname, moco_output_name=None, fixed_output_name=None, fixed_mask_fname=None, reg_par_name=None, iteration_log_fname=None, opt_range=[np.deg2rad(1), 10], init_angle=0, init_axis=[0, 0, 1], relax_factor=0.5, winsorize=None, threshold=None, sigmas=[0], shrink=[1], metric='MS', learning_rate=5, convergence_window_size=10, convergence_value=1E-6, min_step_length=1E-6, nit=250, verbose=2): """Multi-scale rigid body registration ITK registration framework inspired by ANTs which performs a multi-scale 3D versor registratio between two 3D volumes. The input data is provided as .h5 image files. Default values works well. Mask for the fixed image is highly recommended for ZTE data with head rest pads visible. Note that the outputs of the registration is are versor and translation vectors. The versor is the vector part of a unit normalised quarterion. To get the equivalent euler angles use pymerlin.utils.versor_to_euler. Args: fixed_image_fname (str): Fixed file (.h5 file) moving_image_fname (str): Moving file (.h5 file) moco_output_name (str, optional): Output moco image as nifti. Defaults to None. fixed_output_name (str, optional): Output fixed image as nifti. Defaults to None. fixed_mask_name (str, optional): Mask for fixed image. Defaults to None reg_par_name (str, optional): Name of output parameter file. Defaults to None. iteration_log_fname (str, optional): Name for output log file. Defaults to None. opt_range (list, optional): Expected range of motion [deg,mm]. Defaults to [1 rad, 10 mm]. init_angle (float, optional): Initial angle for registration. Defaults to 0 init_axis (array, optional): Direction of intial rotation for registration. Defaults to [0,0,1] relax_factor (float, optional): Relaxation factor for optimizer, factor to decrease step length by. Defaults to 0.5. winsorize (list, optional): Limits for winsorize filter. Defaults to None. threshold (float, optional): Lower value for threshold filter. Defaults to None. sigmas (list, optional): Smoothing sigmas for multi-scale registration. Defaults to [0]. shrink (list, optional): Shring factors for multi-scale registration. Defaults to [1]. metric (str, optional): Image metric for registrationn (MI/MS). Defaults to 'MS'. learning_rate (float, optional): Initial step length. Defaults to 5. convergence_window_size (int, optional): Length of window to calculate convergence value. Defaults to 10. convergence_value (float, optional): Convergence value to terminate registration. Defaults to 1E-6. min_step_length (float, optional): Minimum step length, after which the registration terminates. Defaults to 1E-6, nit (int, optional): Maximum number of iterations per scale. Defaults to 250. verbose (int, optional): Level of debugging (0/1/2). Defaults to 2. Returns: (itk.ImageRegistrationMethodv4, dict, str): Registration object, Registration history, Name of output file with correction factors """ # Logging log_level = {0: None, 1: logging.INFO, 2: logging.DEBUG} logging.basicConfig( format="[%(asctime)s] %(levelname)s: %(message)s", level=log_level[verbose], datefmt="%I:%M:%S") # Global settings. 3D data with itk.D type PixelType = itk.D ImageType = itk.Image[PixelType, 3] # Validate inputs if len(sigmas) != len(shrink): logging.error("Sigma and Shrink arrays not the same length") raise Exception("Sigma and Shrink arrays must be same length") # Read in data logging.info("Reading fixed image: {}".format(fixed_image_fname)) data_fix, spacing_fix = read_image_h5(fixed_image_fname) logging.info("Reading moving image: {}".format(moving_image_fname)) data_move, spacing_move = read_image_h5(moving_image_fname) fixed_image = create_image( data_fix, spacing_fix, dtype=PixelType, max_image_value=1E3) moving_image = create_image( data_move, spacing_move, dtype=PixelType, max_image_value=1E3) # Winsorize filter if winsorize: logging.info("Winsorising images") fixed_win_filter = winsorize_image( fixed_image, winsorize[0], winsorize[1]) moving_win_filter = winsorize_image( moving_image, winsorize[0], winsorize[1]) fixed_image = fixed_win_filter.GetOutput() moving_image = moving_win_filter.GetOutput() if threshold == 'otsu': logging.info("Calculating Otsu filter") filt = otsu_filter(fixed_image) otsu_threshold = filt.GetThreshold() logging.info( "Applying thresholding at Otsu threshold of {}".format(otsu_threshold)) fixed_image = threshold_image(fixed_image, otsu_threshold) moving_image = threshold_image(moving_image, otsu_threshold) elif threshold is not None: logging.info("Thresholding images at {}".format(threshold)) fixed_image = threshold_image(fixed_image, threshold) moving_image = threshold_image(moving_image, threshold) # Setup image metric if metric == 'MI': nbins = 16 logging.info( "Using Mattes Mutual Information image metric with {} bins".format(nbins)) metric = itk.MattesMutualInformationImageToImageMetricv4[ImageType, ImageType].New() metric.SetNumberOfHistogramBins(nbins) metric.SetUseMovingImageGradientFilter(False) metric.SetUseFixedImageGradientFilter(False) else: logging.info("Using Mean Squares image metric") metric = itk.MeanSquaresImageToImageMetricv4[ImageType, ImageType].New( ) # Setup versor transform logging.info("Initialising Versor Rigid 3D Transform") TransformType = itk.VersorRigid3DTransform[PixelType] TransformInitializerType = itk.CenteredTransformInitializer[TransformType, ImageType, ImageType] initialTransform = TransformType.New() initializer = TransformInitializerType.New() initializer.SetTransform(initialTransform) initializer.SetFixedImage(fixed_image) initializer.SetMovingImage(moving_image) initializer.GeometryOn() initializer.InitializeTransform() VersorType = itk.Versor[itk.D] VectorType = itk.Vector[itk.D, 3] rotation = VersorType() axis = VectorType() axis[0] = init_axis[0] axis[1] = init_axis[1] axis[2] = init_axis[2] angle = init_angle rotation.Set(axis, angle) initialTransform.SetRotation(rotation) # Setup optimizer optimizer = setup_optimizer(PixelType, opt_range, relax_factor, nit=int(nit), learning_rate=learning_rate, convergence_window_size=int( convergence_window_size), convergence_value=convergence_value, min_step_length=min_step_length) # Setup registration registration = itk.ImageRegistrationMethodv4[ImageType, ImageType].New() registration.SetMetric(metric) registration.SetOptimizer(optimizer) registration.SetFixedImage(fixed_image) registration.SetMovingImage(moving_image) registration.SetInitialTransform(initialTransform) # One level registration without shrinking and smoothing logging.info("Smoothing sigmas: {}".format(sigmas)) logging.info("Shrink factors: {}".format(shrink)) numberOfLevels = len(sigmas) shrinkFactorsPerLevel = itk.Array[itk.F](numberOfLevels) smoothingSigmasPerLevel = itk.Array[itk.F](numberOfLevels) for i in range(numberOfLevels): shrinkFactorsPerLevel[i] = shrink[i] smoothingSigmasPerLevel[i] = sigmas[i] registration.SetNumberOfLevels(numberOfLevels) registration.SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel) registration.SetShrinkFactorsPerLevel(shrinkFactorsPerLevel) if fixed_mask_fname: logging.info( "Loading fixed mask from file: {}".format(fixed_mask_fname)) MaskType = itk.ImageMaskSpatialObject[3] mask = MaskType.New() data_mask_fix, spacing_mask_fix = read_image_h5(fixed_mask_fname) mask_img = create_image( data_mask_fix, spacing_mask_fix, dtype=itk.UC) mask.SetImage(mask_img) mask.Update() metric.SetFixedImageMask(mask) # Watch the itteration events reg_out = {'cv': [], 'tX': [], 'tY': [], 'tZ': [], 'vX': [], 'vY': [], 'vZ': [], 'sl': [], 'lrr': []} logging.info("Running Registration") wf = versor_watcher(reg_out, optimizer) optimizer.AddObserver(itk.IterationEvent(), wf) # --> Run registration registration.Update() # Correction factors corrections = get_versor_factors(registration) rot_x, rot_y, rot_z = versor_to_euler( [corrections['vx'], corrections['vy'], corrections['vz']]) logging.info("Estimated parameters") logging.info("Rotation: (%.2f, %.2f, %.2f) deg" % (np.rad2deg(rot_x), np.rad2deg(rot_y), np.rad2deg(rot_z))) logging.info("Translation: (%.2f, %.2f, %.2f) mm" % (corrections['dx'], corrections['dy'], corrections['dz'])) # Resample moving data resampler = resample_image(registration, moving_image, fixed_image) # Write output if moco_output_name: logging.info( "Writing moco output image to: {}".format(moco_output_name)) writer = itk.ImageFileWriter[ImageType].New() writer.SetFileName(moco_output_name) writer.SetInput(resampler.GetOutput()) writer.Update() if fixed_output_name: logging.info( "Writing reference imgae to: {}".format(fixed_output_name)) writer = itk.ImageFileWriter[ImageType].New() writer.SetFileName(fixed_output_name) writer.SetInput(fixed_image) writer.Update() if iteration_log_fname: logging.info("Writing iteration log to: %s" % iteration_log_fname) pickle.dump(reg_out, open(iteration_log_fname, "wb")) if not reg_par_name: fix_bname = os.path.splitext(os.path.basename(fixed_image_fname)) move_bname = os.path.splitext(os.path.basename(moving_image_fname)) reg_par_name = "%s_2_%s_reg.p" % (move_bname, fix_bname) logging.info("Writing registration parameters to: %s" % reg_par_name) with open(reg_par_name, 'wb') as f: pickle.dump(corrections, f) return registration, reg_out, reg_par_name
[docs]def histogram_threshold_estimator(img, plot=False, nbins=200): """Estimate background intensity using histogram. Initially used to reduce streaking in background but found to make little difference. Args: img (np.array): Image plot (bool, optional): Plot result. Defaults to False. nbins (int, optional): Number of histogram bins. Defaults to 200. """ def smooth(x, window_len=10, window='hanning'): """smooth the data using a window with requested size. This method is based on the convolution of a scaled window with the signal. The signal is prepared by introducing reflected copies of the signal (with the window size) in both ends so that transient parts are minimized in the begining and end part of the output signal. input: x: the input signal window_len: the dimension of the smoothing window; should be an odd integer window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman' flat window will produce a moving average smoothing. output: the smoothed signal example: t=linspace(-2,2,0.1) x=sin(t)+randn(len(t))*0.1 y=smooth(x) see also: numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve scipy.signal.lfilter TODO: the window parameter could be the window itself if an array instead of a string NOTE: length(output) != length(input), to correct this: return y[(window_len/2-1):-(window_len/2)] instead of just y. -> Obtained from the scipy cookbook at: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html Modified to use np instead of numpy """ if x.ndim != 1: raise(ValueError, "smooth only accepts 1 dimension arrays.") if x.size < window_len: raise(ValueError, "Input vector needs to be bigger than window size.") if window_len < 3: return x s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]] # print(len(s)) if window == 'flat': # moving average w = np.ones(window_len, 'd') else: w = eval('np.'+window+'(window_len)') y = np.convolve(w/w.sum(), s, mode='valid') return y[int(window_len/2-1):-int(window_len/2)] y, x = np.histogram(abs(img.flatten()), bins=nbins) x = (x[1:]+x[:-1])/2 y = smooth(y) dx = (x[1:]+x[:-1])/2 dx2 = (dx[1:]+dx[:-1])/2 dy = np.diff(y) dy2 = np.diff(smooth(dy)) # Peak of histogram imax = np.argmax(y) # Find max of second derivative after this peak dy2max = np.argmax(dy2[imax:]) thr = int(dx2[imax+dy2max]) if plot: plt.figure() plt.plot(x, y/max(y), label='H') plt.plot(dx, dy/np.max(dy), label='dH/dx') ldy2 = plt.plot(dx2, dy2/max(abs(dy2)), label=r'$dH^2/dx^2$') plt.axis([0, 1500, -1, 1]) thr = int(dx2[imax+dy2max]) plt.plot([dx2[imax+dy2max], dx2[imax+dy2max]], [-1, 1], '--', color=ldy2[0].get_color(), label='Thr=%d' % thr) plt.legend() plt.show() return thr
################################################################# # Legacy functions #################################################################
[docs]def versor_resample(registration, moving_image, fixed_image): Dimension = 3 PixelType = itk.D FixedImageType = itk.Image[PixelType, Dimension] MovingImageType = itk.Image[PixelType, Dimension] transform = registration.GetTransform() final_parameters = transform.GetParameters() TransformType = itk.VersorRigid3DTransform[itk.D] finalTransform = TransformType.New() finalTransform.SetFixedParameters( registration.GetOutput().Get().GetFixedParameters()) finalTransform.SetParameters(final_parameters) ResampleFilterType = itk.ResampleImageFilter[MovingImageType, FixedImageType] resampler = ResampleFilterType.New() resampler.SetTransform(finalTransform) resampler.SetInput(moving_image) resampler.SetSize(fixed_image.GetLargestPossibleRegion().GetSize()) resampler.SetOutputOrigin(fixed_image.GetOrigin()) resampler.SetOutputSpacing(fixed_image.GetSpacing()) resampler.SetOutputDirection(fixed_image.GetDirection()) resampler.SetDefaultPixelValue(0) resampler.Update() return resampler.GetOutput()