Source code for nighres.brain.dots_segmentation

import os
import time
import numpy as np
import nibabel as nb
from ..io import load_volume, save_volume
from ..utils import _output_dir_4saving, _fname_4saving


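# Label names for the reduced DOTS atlas (wm_atlas=1): index 0 is isotropic
# tissue, index 1 is unclassified ('other') white matter, and indices 2-22 are
# the individual tracts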
atlas_labels_1 = ['isotropic', 'other_WM', 'ATR_L', 'ATR_R', 'CC_front', 
                  'CC_post', 'CC_sup', 'CG_L', 'CG_R', 'CST_L', 'CST_R', 
                  'IFO_L', 'IFO_R', 'ILF_L', 'ILF_R', 'ML_L', 'ML_R', 'OPR_L', 
                  'OPR_R', 'SLF_L', 'SLF_R', 'UNC_L', 'UNC_R']


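# Pairs of label indices (into atlas_labels_1) that are modeled as potentially
# overlapping tracts, e.g. {2, 4} pairs ATR_L with CC_front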
tract_pair_sets_1 = [{2, 4}, {3, 4}, {2, 6}, {3, 6}, {4, 6}, {4, 7}, {5, 7}, 
                     {6, 7}, {8, 4}, {8, 5}, {8, 6}, {9, 6}, {10, 6}, {9, 10}, 
                     {2, 11}, {11, 4}, {11, 5}, {3, 12}, {12, 4}, {12, 5}, 
                     {10, 12}, {13, 5}, {11, 13}, {5, 14}, {12, 14}, {2, 15}, 
                     {9, 15}, {10, 15}, {16, 9}, {16, 10}, {16, 15}, {17, 5}, 
                     {17, 11}, {17, 13}, {18, 5}, {18, 12}, {18, 14}, {19, 11}, 
                     {19, 13}, {17, 19}, {10, 20}, {20, 12}, {20, 14}, {18, 20}, 
                     {2, 21}, {4, 21}, {11, 21}, {21, 13}, {12, 22}, {22, 14}]


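# Label names for the full DOTS atlas (wm_atlas=2), following the same
# convention: index 0 is isotropic tissue, index 1 is unclassified white
# matter, and indices 2-40 are the individual tracts; tract_pair_sets_2 below
# lists the allowed overlapping pairs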
atlas_labels_2 = ['isotropic', 'other_WM', 'ATR_L', 'ATR_R', 'CC_front', 
                  'CC_post', 'CC_sup', 'CG_L', 'CG_R', 'CPT_L', 'CPT_R', 
                  'CST_L', 'CST_R', 'FNX_L', 'FNX_R', 'ICP_L', 'ICP_R', 'IFO_L', 
                  'IFO_R', 'ILF_L', 'ILF_R', 'MCP', 'ML_L', 'ML_R', 'OPR_L', 
                  'OPR_R', 'OPT_L', 'OPT_R', 'PTR_L', 'PTR_R', 'SCP_L', 'SCP_R', 
                  'SFO_L', 'SFO_R', 'SLF_L', 'SLF_R', 'STR_L', 'STR_R', 'TAP', 
                  'UNC_L', 'UNC_R']


tract_pair_sets_2 = [{2, 4}, {3, 4}, {2, 6}, {3, 6}, {4, 6}, {4, 7}, {5, 7}, 
                     {6, 7}, {8, 4}, {8, 5}, {8, 6}, {9, 2}, {9, 6}, {10, 3}, 
                     {10, 6}, {9, 10}, {11, 6}, {9, 11}, {10, 11}, {12, 6}, 
                     {9, 12}, {10, 12}, {11, 12}, {2, 13}, {13, 7}, {9, 13}, 
                     {5, 14}, {8, 14}, {11, 15}, {16, 12}, {17, 2}, {17, 4}, 
                     {17, 5}, {17, 9}, {17, 13}, {18, 3}, {18, 4}, {18, 5}, 
                     {18, 10}, {18, 12}, {18, 14}, {19, 5}, {9, 19}, {19, 13}, 
                     {17, 19}, {20, 5}, {20, 14}, {18, 20}, {9, 21}, {10, 21}, 
                     {11, 21}, {12, 21}, {21, 15}, {16, 21}, {2, 22}, {9, 22}, 
                     {11, 22}, {12, 22}, {22, 15}, {16, 22}, {21, 22}, {10, 23}, 
                     {11, 23}, {12, 23}, {16, 23}, {21, 23}, {22, 23}, {24, 5}, 
                     {24, 9}, {24, 13}, {24, 17}, {24, 19}, {25, 5}, {25, 10}, 
                     {25, 14}, {25, 18}, {25, 20}, {9, 26}, {26, 11}, {26, 13}, 
                     {17, 26}, {10, 27}, {27, 12}, {27, 14}, {18, 27}, {27, 20}, 
                     {25, 27}, {28, 5}, {9, 28}, {28, 13}, {17, 28}, {19, 28}, 
                     {24, 28}, {29, 5}, {29, 6}, {8, 29}, {10, 29}, {12, 29},
                     {29, 14}, {18, 29}, {20, 29}, {25, 29}, {27, 29}, {2, 30}, 
                     {9, 30}, {11, 30}, {30, 15}, {21, 30}, {30, 22}, {30, 23}, 
                     {16, 31}, {21, 31}, {22, 31}, {31, 23}, {32, 2}, {32, 4}, 
                     {32, 6}, {32, 9}, {32, 11}, {32, 13}, {32, 17}, {32, 28}, 
                     {33, 3}, {33, 4}, {33, 6}, {33, 10}, {33, 14}, {33, 18}, 
                     {33, 29}, {9, 34}, {17, 34}, {34, 19}, {24, 34}, {34, 28}, 
                     {10, 35}, {35, 12}, {18, 35}, {35, 20}, {25, 35}, {35, 29}, 
                     {2, 36}, {36, 6}, {9, 36}, {11, 36}, {17, 36}, {26, 36}, 
                     {32, 36}, {34, 36}, {3, 37}, {37, 6}, {10, 37}, {12, 37}, 
                     {18, 37}, {27, 37}, {37, 29}, {33, 37}, {35, 37}, {5, 38}, 
                     {38, 7}, {8, 38}, {9, 38}, {10, 38}, {13, 38}, {38, 14}, 
                     {17, 38}, {18, 38}, {19, 38}, {20, 38}, {24, 38}, {25, 38}, 
                     {27, 38}, {28, 38}, {29, 38}, {34, 38}, {35, 38}, {2, 39}, 
                     {4, 39}, {13, 39}, {17, 39}, {19, 39}, {38, 39}, {40, 14}, 
                     {40, 18}, {40, 20}, {40, 27}]


def _theta(v_1, v_2):
    # Return angle between vectors v_1 and v_2 normalized to be in [0,1]
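    # e.g. _theta([1,0,0], [0,1,0]) = (2/pi)*arccos(0) = 1.0 (orthogonal) and
    # _theta([1,0,0], [-1,0,0]) = (2/pi)*arccos(1) = 0.0 (anti-parallel)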
    angle = (2/np.pi) * np.arccos(np.abs(np.dot(v_1, v_2)))
    return angle


def _calc_s_T(i, j, k, a, b, c, evecs, v_xy):
    # Return connectivity between voxels i,j,k and a,b,c assuming a single 
    # tract (Eq 5)
    s_T = (1 - np.nanmin([_theta(evecs[i,j,k,:,0], v_xy[a-i+1,b-j+1,c-k+1,:]),
           _theta(evecs[a,b,c,:,0], v_xy[a-i+1,b-j+1,c-k+1,:])])) * \
          (1 - 2*_theta(evecs[i,j,k,:,0], evecs[a,b,c,:,0]))
    return s_T

    
def _calc_s_O(i, j, k, a, b, c, evals, evecs, v_xy):
    # Return connectivity between voxels i,j,k and a,b,c assuming
    # overlapping tracts (Eq 6)
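    # Build the four candidate direction pairs for the two voxels:
    # principal/principal, scaled-secondary/principal,
    # principal/scaled-secondary and scaled-secondary/scaled-secondary,
    # where the secondary eigenvectors are scaled by the ratio of the second
    # to the first eigenvalue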
    temp = np.zeros((4,2,3))
    temp[0,0,:] = evecs[i,j,k,:,0]
    temp[0,1,:] = evecs[a,b,c,:,0]
    temp[1,0,:] = evals[i,j,k,1] / evals[i,j,k,0] * evecs[i,j,k,:,1]
    temp[1,1,:] = evecs[a,b,c,:,0]
    temp[2,0,:] = evecs[i,j,k,:,0]
    temp[2,1,:] = evals[a,b,c,1] / evals[a,b,c,0] * evecs[a,b,c,:,1]
    temp[3,0,:] = evals[i,j,k,1] / evals[i,j,k,0] * evecs[i,j,k,:,1]
    temp[3,1,:] = evals[a,b,c,1] / evals[a,b,c,0] * evecs[a,b,c,:,1]
    vals = np.zeros(4)
    for x in range(4):
        vals[x] = _theta(temp[x,0,:], temp[x,1,:])
    v_O = temp[np.nanargmin(vals),:,:]
    s_O = (1 - np.nanmin([_theta(v_O[0,:], v_xy[a-i+1,b-j+1,c-k+1,:]),
          _theta(v_O[1,:], v_xy[a-i+1,b-j+1,c-k+1,:])])) * \
          (1 - 2*_theta(v_O[0,:], v_O[1,:]))
    return s_O


def _half_pos_nhood(i, j, k, evecs, v_xy):
    # Return half-neighborhood where dot product between principal diffusion
    # direction of voxel i,j,k and the direction between neighboring voxels
    # is positive
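    # The selected offsets are shifted by (i, j, k), so the returned indices
    # are absolute voxel coordinates of the neighbors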
    pos_indices = []
    for a in range(3):
        for b in range(3):
            for c in range(3):
                if np.dot(evecs[i,j,k,:,0], v_xy[a,b,c]) > 0:
                    pos_indices.append(np.array([a,b,c]))
    return np.array(pos_indices) + [i,j,k] - [1,1,1]


def _half_neg_nhood(i, j, k, evecs, v_xy):
    # Return half-neighborhood where dot product between principal diffusion
    # direction of voxel i,j,k and the direction between neighboring voxels
    # is negative
    neg_indices = []
    for a in range(3):
        for b in range(3):
            for c in range(3):
                if np.dot(evecs[i,j,k,:,0], v_xy[a,b,c]) < 0:
                    neg_indices.append(np.array([a,b,c]))         
    return np.array(neg_indices) + [i,j,k] - [1,1,1]


def _calc_x_plus_s_T(i,j,k,evecs,v_xy):
    # Return x plus for s_T
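    # i.e. the neighbor in the positive half-neighborhood that maximizes the
    # single-tract connectivity s_T, returned together with that maximum.
    # The three functions below follow the same pattern for the negative
    # half-neighborhood and for the overlapping-tract connectivity s_O.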
    pos_indices = _half_pos_nhood(i,j,k,evecs,v_xy)
    max_s_T = -np.inf
    for idx in pos_indices:
        temp_s_T = _calc_s_T(i,j,k,idx[0],idx[1],idx[2],evecs,v_xy)
        if temp_s_T > max_s_T:
            max_s_T = temp_s_T
            argmax = idx
    return argmax, max_s_T


def _calc_x_minus_s_T(i,j,k,evecs,v_xy):
    # Return x minus for s_T
    neg_indices = _half_neg_nhood(i,j,k,evecs,v_xy)
    max_s_T = -np.inf
    for idx in neg_indices:
        temp_s_T = _calc_s_T(i,j,k,idx[0],idx[1],idx[2],evecs,v_xy)
        if temp_s_T > max_s_T:
            max_s_T = temp_s_T
            argmax = idx
    return argmax, max_s_T


def _calc_x_plus_s_O(i,j,k,evals,evecs,v_xy):
    # Return x plus for s_O
    pos_indices = _half_pos_nhood(i,j,k,evecs,v_xy)
    max_s_O = -np.inf
    for idx in pos_indices:
        temp_s_O = _calc_s_O(i,j,k,idx[0],idx[1],idx[2],evals,evecs,v_xy)
        if temp_s_O > max_s_O:
            max_s_O = temp_s_O
            argmax = idx
    return argmax, max_s_O


def _calc_x_minus_s_O(i,j,k,evals,evecs,v_xy):
    # Return x minus for s_O
    neg_indices = _half_neg_nhood(i,j,k,evecs,v_xy)
    max_s_O = -np.inf
    for idx in neg_indices:
        temp_s_O = _calc_s_O(i,j,k,idx[0],idx[1],idx[2],evals,evecs,v_xy)
        if temp_s_O > max_s_O:
            max_s_O = temp_s_O
            argmax = idx
    return argmax, max_s_O


def _calc_c_l(i, j, k, l, m, evecs, fiber_dir, c_C):
    # Return direction index (Eqs 8, 9)
    if m is None:
        c_l = np.linalg.norm(fiber_dir[i,j,k,:,l]) * (1 - c_C *
              _theta(evecs[i,j,k,:,0], fiber_dir[i,j,k,:,l] /
              np.linalg.norm(fiber_dir[i,j,k,:,l])))
    else:
        comp_dir = np.stack((fiber_dir[i,j,k,:,l] + fiber_dir[i,j,k,:,m],
                             fiber_dir[i,j,k,:,l] - fiber_dir[i,j,k,:,m]))
        comp_dir = comp_dir[np.argmax(np.linalg.norm(comp_dir, axis=1)),:]
        comp_dir = (comp_dir / np.linalg.norm(comp_dir) * 
                    (np.linalg.norm(fiber_dir[i,j,k,:,l]) +
                     np.linalg.norm(fiber_dir[i,j,k,:,m])) / 2)
        c_l = np.linalg.norm(comp_dir) * (1 - c_C * 
              _theta(evecs[i,j,k,:,0], comp_dir /
              np.linalg.norm(comp_dir)))
    return c_l


def _calc_V1(d_T, d_O, d_I, u_l, u_lm, c_l, c_lm, c_I, fiber_p, 
             tract_pair_sets, N_t, N_o, brain_mask):
    # Return energy using the unary term only (Eq 11)
    xs, ys, zs = d_T.shape
    MRF_V1 = np.zeros((xs, ys, zs, N_t + N_o))
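    # Channel layout along the last axis: 0 = isotropic, 1 = unclassified
    # white matter, 2..N_t-1 = individual tracts, and N_t..N_t+N_o-1 =
    # overlapping tract pairs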
    print('Calculating V1')
    for i in range(xs):
        print(str(np.round((i / xs)*100, 0)) + ' %', end="\r")
        for j in range(ys):
            for k in range(zs):
                if brain_mask[i,j,k]:
                        
                    # Calculate isotropic energy
                    MRF_V1[i,j,k,0] = c_I * d_I[i,j,k] * u_l[i,j,k,0]
                        
                    # Calculate individual tract energies
                    for idx in range(1,N_t):
                        l = idx
                        if fiber_p[i,j,k,l] == 0:
                            MRF_V1[i,j,k,idx] = np.nan
                        else:
                            MRF_V1[i,j,k,idx] = d_T[i,j,k] * u_l[i,j,k,l] * \
                                                c_l[i,j,k,l]

                    # Calculate overlapping tract energies
                    for idx in range(N_t,N_t+N_o):
                        l,m = tract_pair_sets[idx - N_t]
                        if fiber_p[i,j,k,l] == 0 or fiber_p[i,j,k,m] == 0:
                            MRF_V1[i,j,k,idx] = np.nan
                        else:
                            MRF_V1[i,j,k,idx] = d_O[i,j,k] * \
                                                u_lm[i,j,k,idx-N_t] * \
                                                c_lm[i,j,k,idx-N_t]
    return MRF_V1
            

def _calc_U(prev_iter_U, d_T, d_O, d_I, u_l, u_lm, c_l, c_lm, c_I, fiber_p, 
            tract_pair_sets, s_I, s_T_x_p, s_T_x_m, s_O_x_m, s_O_x_p, 
            brain_mask, N_t, N_o, x_m_s_T, x_p_s_T, x_m_s_O, x_p_s_O):
    # Return total energy (Eq 12)
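    # Array form of the tract pair sets, used below to find every overlap
    # channel that involves a given tract label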
    tract_pair_array = np.array([list(i) for i in tract_pair_sets])
    xs, ys, zs = d_T.shape
    curr_U = np.zeros((xs, ys, zs, N_t + N_o))
    
    for i in range(xs):
        print(str(np.round((i / xs)*100, 0)) + ' %', end="\r")
        for j in range(ys):
            for k in range(zs):
                if brain_mask[i,j,k]:
                    
                    # Check isotropic energy
                    curr_U[i,j,k,0] = c_I * d_I[i,j,k] * u_l[i,j,k,0] + \
                                      (s_I * (np.nansum(prev_iter_U[i-1:i+2,
                                       j-1:j+2,k-1:k+2,0]) - prev_iter_U[i,j,
                                       k,0]) / 26)
                    
                    
                    # Check individual tract energy
                    for l in range(1,N_t):
                        if fiber_p[i,j,k,l] == 0:
                            curr_U[i,j,k,l] = np.nan
                        else:
                            idx = np.concatenate(([l], np.where(np.any(
                                                 tract_pair_array
                                                 == l, axis=1))[0]+N_t))
                            curr_U[i,j,k,l] = d_T[i,j,k] * u_l[i,j,k,l] * \
                                              c_l[i,j,k,l] + \
                                              0.5 * (s_T_x_p[i,j,k] * 
                                              np.nanmax(prev_iter_U[
                                                        x_p_s_T[i,j,k,0],
                                                        x_p_s_T[i,j,k,1],
                                                        x_p_s_T[i,j,k,2],
                                                        idx])) + \
                                              0.5 * (s_T_x_m[i,j,k] * 
                                              np.nanmax(prev_iter_U[
                                                        x_m_s_T[i,j,k,0],
                                                        x_m_s_T[i,j,k,1],
                                                        x_m_s_T[i,j,k,2],
                                                        idx]))
                                                
                    # Check overlapping tract energy
                    for idx in range(N_t,N_t+N_o):
                        l,m = tract_pair_sets[idx-N_t]
                        if fiber_p[i,j,k,l] == 0 or fiber_p[i,j,k,m] == 0:
                            curr_U[i,j,k,idx] = np.nan
                        else:
                            curr_U[i,j,k,idx] = d_O[i,j,k] * \
                                                u_lm[i,j,k,idx-N_t] * \
                                                c_lm[i,j,k,idx-N_t] + \
                                                0.5 * (s_O_x_p[i,j,k] *
                                                np.nanmax([prev_iter_U[
                                                           x_p_s_O[i,j,k,0],
                                                           x_p_s_O[i,j,k,1],
                                                           x_p_s_O[i,j,k,2],
                                                           idx],
                                                           prev_iter_U[
                                                           x_p_s_O[i,j,k,0],
                                                           x_p_s_O[i,j,k,1],
                                                           x_p_s_O[i,j,k,2],
                                                           l],
                                                           prev_iter_U[
                                                           x_p_s_O[i,j,k,0],
                                                           x_p_s_O[i,j,k,1],
                                                           x_p_s_O[i,j,k,2],
                                                           m]])) + \
                                               0.5 * (s_O_x_m[i,j,k] *
                                               np.nanmax([prev_iter_U[
                                                          x_m_s_O[i,j,k,0],
                                                          x_m_s_O[i,j,k,1],
                                                          x_m_s_O[i,j,k,2],
                                                          idx],
                                                          prev_iter_U[
                                                          x_m_s_O[i,j,k,0],
                                                          x_m_s_O[i,j,k,1],
                                                          x_m_s_O[i,j,k,2],
                                                          l],
                                                          prev_iter_U[
                                                          x_m_s_O[i,j,k,0],
                                                          x_m_s_O[i,j,k,1],
                                                          x_m_s_O[i,j,k,2],
                                                          m]]))
    return curr_U


def _calc_segmentation(U):
    # Return hard segmentation based on MRF energy U  
    U_temp = np.copy(U)
    U_temp[np.isnan(U_temp)] = -np.inf
    segmentation = np.argmax(U_temp, axis = 3)
    return segmentation


def calc_posterior_probability(l, U, wm_atlas, g0 = None):
    # Return posterior probability of tract l from MRF energy U (Eq 15)               
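    # Example (sketch): with U the energy array produced by the iterated
    # conditional modes loop in dots_segmentation below and the reduced
    # atlas, the posterior of CG_L (label 7 in atlas_labels_1) would be
    #     calc_posterior_probability(7, U, wm_atlas=1)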
    if wm_atlas == 1:
        N_t = 23
        tract_pair_array = np.array([list(i) for i in tract_pair_sets_1])
    elif wm_atlas == 2:
        N_t = 41
        tract_pair_array = np.array([list(i) for i in tract_pair_sets_2])
    if g0 is None:
        g0 = N_t
    idx = np.concatenate(([l], np.where(np.any(tract_pair_array 
                                               == l, axis=1))[0]+N_t))
    posterior_l = (np.nansum(np.exp(g0*U[:,:,:,idx]), axis=3) /
                   np.nansum(np.exp(g0*U),axis=3))
    return posterior_l
    

def dots_segmentation(tensor_image, mask, atlas_dir, wm_atlas=1, max_iter=25,
                      convergence_threshold=0.005, s_I=1/42, c_O=0.5,
                      max_angle=67.5, save_data=False, overwrite=False,
                      output_dir=None, file_name=None):
    """DOTS segmentation

    Segment major white matter tracts in diffusion tensor images using the
    Diffusion Oriented Tract Segmentation (DOTS) algorithm.

    Parameters
    ----------
    tensor_image: niimg
        Input image containing the diffusion tensor coefficients in the
        following order: volumes 0-5: D11, D22, D33, D12, D13, D23
    mask: niimg
        Binary brain mask image which limits computation to the defined
        volume.
    atlas_dir: str
        Path to directory where the DOTS atlas information is stored. The
        atlas information should be stored in a subdirectory called
        'DOTS_atlas' as generated by nighres.data.download_DOTS_atlas().
    wm_atlas: int, optional
        Define which white matter atlas to use. Option 1 for the reduced
        atlas with 23 labels (21 tracts) [2]_ and option 2 for the full
        atlas with 41 labels (39 tracts) [1]_. (default is 1)
    max_iter: int, optional
        Maximum number of iterations of the iterated conditional modes
        algorithm. (default is 25)
    convergence_threshold: float, optional
        Threshold for when the iterated conditional modes algorithm is
        considered to have converged. Defined as the fraction of labels
        that change during one step of the algorithm. (default is 0.005)
    s_I: float, optional
        Parameter controlling how isotropic label energies propagate to
        their neighborhood. (default is 1/42)
    c_O: float, optional
        Weight parameter for the unclassified white matter atlas prior.
        (default is 1/2)
    max_angle: float, optional
        Maximum angle (in degrees) between principal tensor directions
        before the connectivity coefficient c becomes negative. Possible
        values between 0 and 90. (default is 67.5)
    save_data: bool, optional
        Save output data to file. (default is False)
    overwrite: bool, optional
        Overwrite existing results. (default is False)
    output_dir: str, optional
        Path to desired output directory, will be created if it doesn't
        exist.
    file_name: str, optional
        Desired base name for output files without file extension,
        suffixes will be added.

    Returns
    ----------
    dict
        Dictionary collecting outputs under the following keys
        (type of output files in brackets)

        * segmentation (array_like): Hard segmentation of white matter.
        * posterior (array_like): Posterior probabilities of tracts.

    Notes
    ----------
    Algorithm details can be found in the references below.

    References
    ----------
    .. [1] Bazin, Pierre-Louis, et al. "Direct segmentation of the major
       white matter tracts in diffusion tensor images." NeuroImage (2011)
       doi: https://doi.org/10.1016/j.neuroimage.2011.06.020
    .. [2] Bazin, Pierre-Louis, et al. "Efficient MRF segmentation of DTI
       white matter tracts using an overlapping fiber model." Proceedings
       of the International Workshop on Diffusion Modelling and Fiber Cup
       (2009)
    """

    print('\nDOTS white matter tract segmentation')

    # make sure that saving related parameters are correct
    if save_data:
        output_dir = _output_dir_4saving(output_dir, tensor_image)

        seg_file = os.path.join(output_dir,
                        _fname_4saving(module=__name__, file_name=file_name,
                                       rootfile=tensor_image,
                                       suffix='dots-seg'))

        proba_file = os.path.join(output_dir,
                        _fname_4saving(module=__name__, file_name=file_name,
                                       rootfile=tensor_image,
                                       suffix='dots-proba'))

        if overwrite is False \
            and os.path.isfile(seg_file) and os.path.isfile(proba_file):
            print("skip computation (use existing results)")
            output = {'segmentation': seg_file,
                      'posterior': proba_file}
            return output

    # For external tools: dipy
    try:
        from dipy.align.transforms import AffineTransform3D
        from dipy.align.imaffine import MutualInformationMetric, \
                                        AffineRegistration
    except ImportError:
        print('Error: Dipy could not be imported, it is required'
              + ' in order to run DOTS segmentation. \n (aborting)')
        return None

    # Ignore runtime warnings that arise from trying to divide by 0/nan
    # and all nan slices
    np.seterr(divide='ignore', invalid='ignore')

    # Define the scalar constant c_I
    c_I = 1/2

    # Define constant c_C that is used in direction coefficient calculation
    c_C = 90 / max_angle

    # Create an array containing the directions between neighbors
    v_xy = np.zeros((3, 3, 3, 3))
    for i in range(3):
        for j in range(3):
            for k in range(3):
                if (i, j, k) == (1, 1, 1):
                    v_xy[i,j,k,:] = np.nan
                else:
                    x = np.array([1, 0, 0])
                    y = np.array([0, 1, 0])
                    z = np.array([0, 0, 1])
                    c = np.array([1, 1, 1])
                    v_xy[i,j,k,:] = i*x + y*j + z*k - c
                    v_xy[i,j,k,:] = v_xy[i,j,k,:] / \
                                    np.linalg.norm(v_xy[i,j,k,:])

    # Load tensor image
    tensor_volume = load_volume(tensor_image).get_data()

    # Load brain mask
    brain_mask = load_volume(mask).get_data().astype(bool)

    # Get dimensions of diffusion data
    xs, ys, zs, _ = tensor_volume.shape
    DWI_affine = load_volume(tensor_image).affine

    # Calculate diffusion tensor eigenvalues and eigenvectors
    tenfit = np.zeros((xs, ys, zs, 3, 3))
    tenfit[:,:,:,0,0] = tensor_volume[:,:,:,0]
    tenfit[:,:,:,1,1] = tensor_volume[:,:,:,1]
    tenfit[:,:,:,2,2] = tensor_volume[:,:,:,2]
    tenfit[:,:,:,0,1] = tensor_volume[:,:,:,3]
    tenfit[:,:,:,1,0] = tensor_volume[:,:,:,3]
    tenfit[:,:,:,0,2] = tensor_volume[:,:,:,4]
    tenfit[:,:,:,2,0] = tensor_volume[:,:,:,4]
    tenfit[:,:,:,1,2] = tensor_volume[:,:,:,5]
    tenfit[:,:,:,2,1] = tensor_volume[:,:,:,5]
    tenfit[np.isnan(tenfit)] = 0
    evals, evecs = np.linalg.eig(tenfit)
    evals, evecs = np.real(evals), np.real(evecs)
    for i in range(xs):
        for j in range(ys):
            for k in range(zs):
                # Sort eigenvalues and eigenvectors in descending order
                idx = np.argsort(evals[i,j,k,:])[::-1]
                evecs[i,j,k,:,:] = evecs[i,j,k,:,idx].T
                evals[i,j,k,:] = evals[i,j,k,idx]
    evals[~brain_mask] = 0
    evecs[~brain_mask] = 0

    # Calculate FA
    R = tenfit / np.trace(tenfit, axis1=3,
                          axis2=4)[:,:,:,np.newaxis,np.newaxis]
    FA = np.sqrt(0.5 * (3 - 1/(np.trace(np.matmul(R,R), axis1=3, axis2=4))))
    FA[np.isnan(FA)] = 0

    if wm_atlas == 1:
        # Use smaller atlas
        # Indices are
        # 0 for isotropic regions
        # 1 for unclassified white matter
        # 2-22 for individual tracts
        # 23-72 for overlapping tracts
        N_t = 23
        N_o = 50
        atlas_path = os.path.join(atlas_dir, 'DOTS_atlas')
        fiber_p = nb.load(os.path.join(atlas_path,
                                       'fiber_p.nii.gz')).get_data()
        max_p = np.nanmax(fiber_p[:,:,:,2::], axis=3)
        fiber_dir = nb.load(os.path.join(atlas_path,
                                         'fiber_dir.nii.gz')).get_data()
        atlas_affine = nb.load(os.path.join(atlas_path,
                                            'fiber_p.nii.gz')).affine
        # Remove atlas channels that are not part of the reduced atlas
        del_idx = [9, 10, 13, 14, 15, 16, 21, 26, 27, 28, 29, 30, 31, 32, 33,
                   36, 37, 38]
        fiber_p = np.delete(fiber_p, del_idx, axis=3)
        fiber_dir = np.delete(fiber_dir, del_idx, axis=4)
        tract_pair_sets = tract_pair_sets_1
    elif wm_atlas == 2:
        # Use full atlas
        # Indices are
        # 0 for isotropic regions
        # 1 for unclassified white matter
        # 2-40 for individual tracts
        # 41-225 for overlapping tracts
        N_t = 41
        N_o = 185
        atlas_path = os.path.join(atlas_dir, 'DOTS_atlas')
        fiber_p = nb.load(os.path.join(atlas_path,
                                       'fiber_p.nii.gz')).get_data()
        max_p = np.nanmax(fiber_p[:,:,:,2::], axis=3)
        fiber_dir = nb.load(os.path.join(atlas_path,
                                         'fiber_dir.nii.gz')).get_data()
        atlas_affine = nb.load(os.path.join(atlas_path,
                                            'fiber_p.nii.gz')).affine
        tract_pair_sets = tract_pair_sets_2

    print('Diffusion and atlas data loaded')

    # Register atlas priors to DWI data with DiPy
    print('Registering atlas priors to DWI data')
    metric = MutualInformationMetric(nbins=32, sampling_proportion=None)
    affreg = AffineRegistration(metric=metric,
                                level_iters=[10000, 1000, 100],
                                sigmas=[3.0, 1.0, 0.0],
                                factors=[4, 2, 1])
    transformation = affreg.optimize(FA, max_p, AffineTransform3D(),
                                     params0=None,
                                     static_grid2world=DWI_affine,
                                     moving_grid2world=atlas_affine,
                                     starting_affine='mass')
    reg_fiber_p = np.zeros((xs, ys, zs, fiber_p.shape[-1]))
    for i in range(fiber_p.shape[-1]):
        reg_fiber_p[:,:,:,i] = transformation.transform(fiber_p[:,:,:,i])
    fiber_p = reg_fiber_p
    reg_fiber_dir = np.zeros((xs, ys, zs, 3, fiber_dir.shape[-1]))
    for i in range(fiber_dir.shape[-1]):
        for j in range(3):
            reg_fiber_dir[:,:,:,j,i] = transformation.transform(
                                           fiber_dir[:,:,:,j,i])
    fiber_dir = reg_fiber_dir
    fiber_p[~brain_mask,0] = 1
    fiber_p[~brain_mask,1:] = 0
    fiber_dir[~brain_mask] = 0
    print('Finished registration of atlas priors to DWI data')

    # Calculate diffusion type indices
    print('Calculating d_T, d_O, d_I')
    d_T = (evals[:,:,:,0] - evals[:,:,:,1]) / evals[:,:,:,0]
    d_O = (evals[:,:,:,0] - evals[:,:,:,2]) / evals[:,:,:,0]
    d_I = evals[:,:,:,2] / evals[:,:,:,0]
    print('Finished calculating d_T, d_O, d_I')

    # Calculate xplus and xminus
    x_m_s_T = np.zeros((xs, ys, zs, 3))
    x_p_s_T = np.zeros((xs, ys, zs, 3))
    x_m_s_O = np.zeros((xs, ys, zs, 3))
    x_p_s_O = np.zeros((xs, ys, zs, 3))
    s_T_x_m = np.zeros((xs, ys, zs))
    s_T_x_p = np.zeros((xs, ys, zs))
    s_O_x_m = np.zeros((xs, ys, zs))
    s_O_x_p = np.zeros((xs, ys, zs))
    print('Calculating x^+, x^-, s_T, s_O')
    for i in range(1, xs-1):
        print(str(np.round((i / xs)*100, 0)) + ' %', end="\r")
        for j in range(1, ys-1):
            for k in range(1, zs-1):
                if brain_mask[i,j,k]:
                    x_m_s_T[i,j,k,:], s_T_x_m[i,j,k] = \
                        _calc_x_minus_s_T(i,j,k,evecs,v_xy)
                    x_p_s_T[i,j,k,:], s_T_x_p[i,j,k] = \
                        _calc_x_plus_s_T(i,j,k,evecs,v_xy)
                    x_m_s_O[i,j,k,:], s_O_x_m[i,j,k] = \
                        _calc_x_minus_s_O(i,j,k,evals,evecs,v_xy)
                    x_p_s_O[i,j,k,:], s_O_x_p[i,j,k] = \
                        _calc_x_plus_s_O(i,j,k,evals,evecs,v_xy)
    x_p_s_T = x_p_s_T.astype(int)
    x_m_s_T = x_m_s_T.astype(int)
    x_p_s_O = x_p_s_O.astype(int)
    x_m_s_O = x_m_s_O.astype(int)
    print('Finished calculating x^+, x^-, s_T, s_O')

    # Calculate shape prior arrays
    print('Calculating u_l, u_lm')
    u_l = fiber_p**2 / np.nansum(fiber_p, axis=3)[:,:,:,np.newaxis]
    u_lm = np.zeros((xs, ys, zs, len(tract_pair_sets)))
    for idx in range(len(tract_pair_sets)):
        l, m = tract_pair_sets[idx]
        u_lm[:,:,:,idx] = fiber_p[:,:,:,l] * fiber_p[:,:,:,m] * \
                          (fiber_p[:,:,:,l] + fiber_p[:,:,:,m]) / \
                          np.nansum(fiber_p, axis=3)
    u_l[:,:,:,1] *= c_O  # Scale by weight parameter
    print('Finished calculating u_l, u_lm')

    # Calculate direction coefficients
    c_l = np.zeros((xs, ys, zs, N_t))*np.nan
    c_lm = np.zeros((xs, ys, zs, len(tract_pair_sets)))*np.nan
    print('Calculating c_l, c_lm')
    for i in range(xs):
        print(str(np.round((i / xs)*100, 0)) + ' %', end="\r")
        for j in range(ys):
            for k in range(zs):
                for l in range(1, N_t):
                    if fiber_p[i,j,k,l] != 0:
                        c_l[i,j,k,l] = _calc_c_l(i,j,k,l,None,evecs,
                                                 fiber_dir,c_C)
                for idx in range(len(tract_pair_sets)):
                    l, m = tract_pair_sets[idx]
                    if fiber_p[i,j,k,l] != 0 and fiber_p[i,j,k,m] != 0:
                        c_lm[i,j,k,idx] = _calc_c_l(i,j,k,l,m,evecs,
                                                    fiber_dir,c_C)
    print('Finished calculating c_l, c_lm')

    # Mask arrays
    d_T[~brain_mask] = np.nan
    d_O[~brain_mask] = np.nan
    d_I[~brain_mask] = 1
    fiber_p[~brain_mask,0] = 1
    fiber_p[~brain_mask,1:] = np.nan
    fiber_dir[~brain_mask] = np.nan
    c_l[~brain_mask] = np.nan
    c_lm[~brain_mask] = np.nan
    u_l[~brain_mask] = np.nan
    u_l[~brain_mask,0] = 1
    u_lm[~brain_mask] = np.nan
    s_T_x_p[~brain_mask] = np.nan
    s_T_x_m[~brain_mask] = np.nan
    s_O_x_p[~brain_mask] = np.nan
    s_O_x_m[~brain_mask] = np.nan

    # Only ROIs where p != 0 are of interest
    u_l[u_l == 0] = np.nan
    u_lm[u_lm == 0] = np.nan

    # Calculate energy based on unary term only
    MRF_V1 = _calc_V1(d_T, d_O, d_I, u_l, u_lm, c_l, c_lm, c_I, fiber_p,
                      tract_pair_sets, N_t, N_o, brain_mask)

    # Maximize U
    print('Maximizing U')
    curr_U = np.copy(MRF_V1)
    iteration = 0
    change_in_labels = np.inf
    while iteration < max_iter and change_in_labels > convergence_threshold:
        at = time.time()
        prev_U = np.copy(curr_U)
        prev_segmentation = _calc_segmentation(prev_U)
        iteration += 1
        print('Iteration ' + str(iteration))
        curr_U = _calc_U(prev_U, d_T, d_O, d_I, u_l, u_lm, c_l, c_lm, c_I,
                         fiber_p, tract_pair_sets, s_I, s_T_x_p, s_T_x_m,
                         s_O_x_m, s_O_x_p, brain_mask, N_t, N_o, x_m_s_T,
                         x_p_s_T, x_m_s_O, x_p_s_O)
        curr_segmentation = _calc_segmentation(curr_U)
        change_in_labels = (np.nansum(prev_segmentation != curr_segmentation)
                            / np.nansum(brain_mask))
        bt = time.time()
        print('Iteration ' + str(iteration) + ' took ' + str(bt-at) +
              ' seconds')
        print('Total U = ' + str(np.nansum(curr_U)))
        print('Fraction of changed labels = ' + str(change_in_labels))
    print('Finished maximizing U')

    # Calculate posterior probabilities
    print('Calculating posterior probabilities')
    fiber_posterior = np.zeros(fiber_p.shape)
    curr_U[curr_U == 0] = np.nan
    for l in range(N_t):
        print(str(np.round((l / N_t)*100, 0)) + ' %', end="\r")
        fiber_posterior[:,:,:,l] = calc_posterior_probability(l, curr_U,
                                                              wm_atlas)
    fiber_posterior[fiber_posterior == 0] = np.nan
    fiber_posterior[np.isinf(fiber_posterior)] = np.nan
    curr_U[np.isnan(curr_U)] = 0
    print('Finished calculating posterior probabilities')

    # Save results
    if save_data:
        save_volume(seg_file, nb.Nifti1Image(curr_segmentation, DWI_affine))
        save_volume(proba_file, nb.Nifti1Image(fiber_posterior, DWI_affine))
        return {'segmentation': seg_file, 'posterior': proba_file}
    else:
        # Return results
        return {'segmentation': curr_segmentation,
                'posterior': fiber_posterior}
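

# Example usage (a sketch; the file paths below are hypothetical and assume
# that the DOTS atlas has already been downloaded into <atlas_dir>/DOTS_atlas,
# e.g. with nighres.data.download_DOTS_atlas()):
#
#     results = dots_segmentation(
#         tensor_image='/data/sub-01_dti_tensor.nii.gz',
#         mask='/data/sub-01_brain_mask.nii.gz',
#         atlas_dir='/data/atlases',
#         wm_atlas=1,
#         save_data=False)
#     hard_seg = results['segmentation']   # 3D array of winning labels
#     posteriors = results['posterior']    # 4D array of tract posteriors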