#!/usr/bin/env fslpython
#   Copyright (C) 2016 University of Oxford
#   Part of FSL - FMRIB's Software Library
#   http://www.fmrib.ox.ac.uk/fsl
#   fsl@fmrib.ox.ac.uk
#
#   Developed at FMRIB (Oxford Centre for Functional Magnetic Resonance
#   Imaging of the Brain), Department of Clinical Neurology, Oxford
#   University, Oxford, UK
#
#
#   LICENCE
#
#   FMRIB Software Library, Release 6.0 (c) 2018, The University of
#   Oxford (the "Software")
#
#   The Software remains the property of the Oxford University Innovation
#   ("the University").
#
#   The Software is distributed "AS IS" under this Licence solely for
#   non-commercial use in the hope that it will be useful, but in order
#   that the University as a charitable foundation protects its assets for
#   the benefit of its educational and research purposes, the University
#   makes clear that no condition is made or to be implied, nor is any
#   warranty given or to be implied, as to the accuracy of the Software,
#   or that it will be suitable for any particular purpose or for use
#   under any specific conditions. Furthermore, the University disclaims
#   all responsibility for the use which is made of the Software. It
#   further disclaims any liability for the outcomes arising from using
#   the Software.
#
#   The Licensee agrees to indemnify the University and hold the
#   University harmless from and against any and all claims, damages and
#   liabilities asserted by third parties (including claims for
#   negligence) which arise directly or indirectly from the use of the
#   Software or the sale of any products based on the Software.
#
#   No part of the Software may be reproduced, modified, transmitted or
#   transferred in any form or by any means, electronic or mechanical,
#   without the express permission of the University. The permission of
#   the University is not required if the said reproduction, modification,
#   transmission or transference is done without financial return, the
#   conditions of this Licence are imposed upon the receiver of the
#   product, and all original and amended source code is included in any
#   transmitted product. You may be held legally responsible for any
#   copyright infringement that is caused or encouraged by your failure to
#   abide by these terms and conditions.
#
#   You are not permitted under this Licence to use this Software
#   commercially. Use for which any financial return is received shall be
#   defined as commercial use, and includes (1) integration of all or part
#   of the source code or the Software into a product for sale or license
#   by or on behalf of Licensee to third parties or (2) use of the
#   Software or any derivative of it for research with the final aim of
#   developing software products for sale or license to a third party or
#   (3) use of the Software or any derivative of it for research with the
#   final aim of developing non-software products for sale or license to a
#   third party, or (4) use of the Software to provide any service to an
#   external organisation for which payment is received. If you are
#   interested in using the Software commercially, please contact Oxford
#   University Innovation ("OUI"), the technology transfer company of the
#   University, to negotiate a licence. Contact details are:
#   fsl@innovation.ox.ac.uk quoting Reference Project 9564, FSL.
"""
Segmentation method based on variable kernel density techniques - a generalisation of kNN

Uses sklearn.neighbors library for k-nearest-neighbour calculations

Training
  - assume pre-processing is done beforehand (bias-field correction + registration)
  - select candidate training voxels (all lesion + some/all non-lesion)
  - calculate features  (intensities, coordinates, PVE)
  - perform intensity normalisation per feature (variance-based; robust-range; IQR/robust-var)
  - build NN tree

Classification
  - calculate features
  - normalise features
  - look up kNN from NN tree
  - compute probability using appropriate kernel (uses distances and class labels of kNN)
  - store probability in each voxel
"""

import getopt
import os
import pickle
import sys
import textwrap as tw

import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy import ndimage

import fsl.data.image as fslimage


#######################################################################################

# Global options


class MyOptions:
    """Empty attribute container; settings are attached to an instance as plain attributes."""


# Single module-wide options object that the functions in this file read from and write to
GlobalOpts = MyOptions()

#######################################################################################

# Generic support functions


# The following support functions are needed because numpy is very inconsistent and fussy
#  about whether vectors are of shape (n,1) or (n,)
def make2d(mat):
    """Return *mat* with 1-D numpy arrays reshaped into a single column.

    Anything that is not a numpy array, and any array that is not 1-D, is
    handed back untouched.
    """
    # the degenerate, single-modality case arrives as shape (n,) and must become (n, 1)
    if isinstance(mat, np.ndarray) and len(mat.shape) == 1:
        return mat.reshape(len(mat), 1)
    return mat


def make1d(vec):
    """Return *vec* as a 1-D array by taking the first column of multi-dimensional arrays.

    Anything that is not a numpy array, and any array that is already 1-D, is
    returned unchanged.
    """
    needs_flattening = isinstance(vec, np.ndarray) and len(vec.shape) != 1
    return vec[:, 0] if needs_flattening else vec


# The following support functions are needed because numpy can't deal well with
#  null elements in calls to hstack or vstack
def h_stack(x):  # make it work better when null elements exist
    """Horizontally stack the non-empty entries of *x*.

    Entries are coerced with np.asarray first, so plain (possibly empty)
    lists are accepted alongside ndarrays; calc_coord_features, for example,
    returns [] when coordinates are not requested.

    Returns the np.hstack of the non-empty entries, or an empty 1-D array
    when every entry is empty (or *x* itself is empty).
    """
    arrays = (np.asarray(part) for part in x)
    populated = [arr for arr in arrays if arr.size > 0]  # strip out null parts
    if populated:
        return np.hstack(populated)
    return np.array([])


def v_stack(x):  # make it work better when null elements exist
    """Vertically stack the non-empty entries of *x*.

    Entries are coerced with np.asarray first, so plain (possibly empty)
    lists are accepted alongside ndarrays.

    Returns the np.vstack of the non-empty entries, or an empty 1-D array
    when every entry is empty (or *x* itself is empty).
    """
    arrays = (np.asarray(part) for part in x)
    populated = [arr for arr in arrays if arr.size > 0]  # strip out null parts
    if populated:
        return np.vstack(populated)
    return np.array([])


# Convert a list of lists into an ndarray
def list2array(x):
    """Copy a list of lists into a float ndarray of shape (len(x), len(x[0])).

    The column count is taken from the first row; cells that a shorter row
    does not provide keep their initial value of zero.
    """
    arr = np.zeros([len(x), len(x[0])])
    for i, values in enumerate(x):
        for j, value in enumerate(values):
            arr[i, j] = value
    return arr


def nonzeroroi(x):
    """Return [xmin, xmax, ymin, ymax, zmin, zmax] bounding the non-zero entries of *x*.

    Both ends of each pair are inclusive indices along the first three axes.
    """
    ids = np.nonzero(x)
    bounds = []
    for axis in range(3):
        bounds.append(min(ids[axis]))
        bounds.append(max(ids[axis]))
    return bounds


def extractroi(x, roi):
    """Crop *x* to the inclusive bounds roi = [xmin, xmax, ymin, ymax, zmin, zmax]."""
    # upper bounds are inclusive, hence the +1 on every slice stop
    window = tuple(slice(roi[2 * axis], roi[2 * axis + 1] + 1) for axis in range(3))
    return x[window]


def vox2mni_mat(flirtmat, im):
    """Compose the 4x4 matrix GlobalOpts.mnisform . flirtmat . vox2mm for image *im*.

    vox2mm scales voxel indices by the header zooms. When a valid (code > 0)
    qform or sform has a positive determinant, the x axis is additionally
    mirrored about the image extent before scaling.
    """
    hdr = im.header
    zooms = hdr.get_zooms()
    scaling = np.identity(4)
    for axis in range(3):
        scaling[axis, axis] = zooms[axis]
    qform, qcode = hdr.get_qform(True)
    sform, scode = hdr.get_sform(True)
    # either orientation matrix, if set, can flag the image as non-radiological
    qform_positive = qcode > 0 and np.linalg.det(qform) > 0
    sform_positive = scode > 0 and np.linalg.det(sform) > 0
    if qform_positive or sform_positive:  # convert to nifti voxel coordinate convention
        xflip = np.identity(4)
        xflip[0, 0] = -1
        xflip[0, 3] = im.shape[0] - 1
        scaling = np.dot(scaling, xflip)
    return np.dot(GlobalOpts.mnisform, np.dot(flirtmat, scaling))


# Flexible extraction of values from an ND array based on an array of voxel coordinates
# This will extract timeseries if the coordinates are short by one
def getvals(valim, pts):
    """Look up the values of *valim* at the voxel coordinates in *pts*.

    *pts* may hold one coordinate per column or one per row; the layout is
    chosen by comparing the number of image dimensions against each axis
    length of *pts* (a match, or one short, which leaves the last image axis
    unindexed). When both layouts fit, the row layout wins. The result goes
    through make2d; None is returned when neither layout fits.
    """
    ndim = len(valim.shape)
    picked = None
    if ndim - pts.shape[0] in (0, 1):
        picked = valim[tuple(pts)]  # each coordinate is a column
    if ndim - pts.shape[1] in (0, 1):
        picked = valim[tuple(pts.transpose())]  # each coordinate is a row
    return make2d(picked)


def put_vox_vals(pts, vals, refim):
    """Return a float32 zero image shaped like *refim* with *vals* written at *pts*.

    *pts* may hold one coordinate per column or one per row; the column
    layout is tried first. If neither layout fits the image dimensionality
    the zero image is returned as is.
    """
    out = np.zeros(refim.shape, dtype=np.float32)
    ndim = len(out.shape)
    if ndim - pts.shape[0] in (0, 1):
        out[tuple(pts)] = make1d(vals)  # each coordinate is a column
    elif ndim - pts.shape[1] in (0, 1):
        out[tuple(pts.transpose())] = make1d(vals)  # each coordinate is a row
    return out


def put_les_vals(cluster_im, vals):
    """Return a float32 copy of *cluster_im* where every voxel of cluster n holds vals[n - 1].

    Voxels labelled zero keep the value zero.
    """
    painted = cluster_im.astype(np.float32)
    top_label = cluster_im.max()
    label_axis = np.arange(top_label + 1)
    for label in range(1, top_label + 1):
        # boolean lookup table, true only at this label, indexed by the label image
        is_this_label = (label_axis == label)[cluster_im]
        painted[is_this_label] = vals[label - 1]  # there is no entry for cluster "zero" in vals[]
    return painted


def sublist(list, indices):
    """Return the entries of *list* at *indices*.

    An empty list, or a list whose first entry is falsy (e.g. the [''] 
    placeholder), is passed through unchanged.
    """
    # NOTE: the parameter name shadows the builtin; kept so keyword callers still work
    if not list or not list[0]:
        return list
    return [list[i] for i in indices]


def transform_coords(mat, xx, yy, zz):
    """Apply the affine *mat* (its first three rows) to the coordinates xx, yy, zz.

    Returns [newxx, newyy, newzz], each with the shape of the inputs.
    """

    def apply_row(row):
        return xx * mat[row, 0] + yy * mat[row, 1] + zz * mat[row, 2] + mat[row, 3]

    return [apply_row(row) for row in range(3)]


def generate_mni_coords(affmat, imobj, label_im=None):
    """Compute transformed (vox2mni_mat) coordinates for the voxel grid of *imobj*.

    Parameters:
      affmat   - 4x4 matrix passed to vox2mni_mat; identity is used when None
      imobj    - image object; its 3-D shape defines the voxel grid
      label_im - optional integer label image

    Returns:
      without label_im: [x, y, z], three arrays holding the transformed
      coordinate of every voxel;
      with label_im: an array with one row per label 1..label_im.max() and
      three columns holding the mean transformed x, y and z of that label
      (an empty array when there are no labels).
    """
    if affmat is None:
        affmat = np.identity(4)
    vox2mni = vox2mni_mat(affmat, imobj)
    nx, ny, nz = imobj.shape
    xx, yy, zz = np.mgrid[0:nx, 0:ny, 0:nz]
    x, y, z = transform_coords(vox2mni, xx, yy, zz)
    if label_im is None:
        return [x, y, z]
    # Average the transformed coordinates; the previous code averaged the raw
    # voxel indices (xx, yy, zz) here and silently ignored the transform.
    label_ids = range(1, label_im.max() + 1)
    columns = [
        make2d(np.asarray(ndimage.mean(coord, label_im, label_ids)))
        for coord in (x, y, z)
    ]
    return h_stack(columns)


#######################################################################################

# Main mathematical functions of interest


def select_coords(mask, numpts=1, allpoints=False, repetitions=False):
    """Pick voxel coordinates from the non-zero entries of *mask*.

    Parameters:
      mask        - array; its non-zero entries are the candidate voxels
      numpts      - number of coordinates wanted (capped at the number of candidates)
      allpoints   - return every candidate, in np.nonzero order, ignoring numpts
      repetitions - draw with replacement instead of without

    Returns an integer array with one coordinate per row; it has zero rows
    when the mask has no non-zero entries.
    """
    coords = np.transpose(np.nonzero(mask))
    if allpoints:
        return coords
    ncoords = coords.shape[0]
    if ncoords == 0:
        # nothing to sample from; the random index paths below cannot handle this
        return coords
    npts = min(ncoords, numpts)
    if repetitions:
        # randint's upper bound is exclusive, so every index 0..ncoords-1 can be
        # drawn (the deprecated random_integers call could never return index 0)
        idx = np.random.randint(0, ncoords, size=npts)
    else:
        idx = np.random.permutation(ncoords)[:npts]
    return coords[idx]


def choose_pts(imdata, numpts=1, nonlespts=1, lesmask=None, mode=None):
    """Choose the voxel coordinates to use for one subject.

    Without a (non-zero) lesion mask, as for a query subject, every non-zero
    voxel of *imdata* is returned and *mode* is ignored. Otherwise lesion
    points, optional surround points and non-lesion points are drawn, in that
    order, and stacked with v_stack.

    Entries of *mode* that are acted on here:
      "equalpoints" - take all lesion voxels and as many non-lesion voxels
      "surround"    - also sample the dilated border around the lesions
      "noborder"    - keep non-lesion samples away from the lesion border
    """
    if lesmask is None:
        lesmask = np.array([])
    if mode is None:
        mode = ['any', 'equalpoints']
    if not lesmask.any():
        return select_coords(imdata, allpoints=True)

    equal_points = "equalpoints" in mode
    want_surround = "surround" in mode
    want_noborder = "noborder" in mode

    lescoords = select_coords(lesmask, numpts=numpts, allpoints=equal_points)
    n_lesion = lescoords.shape[0]

    excluded = lesmask
    if want_surround or want_noborder:
        # "noborder" takes precedence over "surround" for the dilation width
        bordersize = 3 if want_noborder else 5
        footprint_shape = [bordersize, bordersize, bordersize]
        # no dilation along the axis with the largest voxel dimension
        footprint_shape[np.argmax(GlobalOpts.pixdims)] = 1
        excluded = ndimage.grey_dilation(lesmask, footprint=np.ones(footprint_shape))

    if want_surround:
        border = (excluded - lesmask) * (imdata > 0)
        surcoords = select_coords(border, numpts=n_lesion, allpoints=False)
    else:
        surcoords = np.array([])

    # force equal number of non-lesion points (unless explicit number of points specified)
    brnpts = n_lesion if equal_points else nonlespts
    brncoords = select_coords(
        imdata - imdata * excluded, numpts=brnpts, allpoints=False
    )
    return v_stack([lescoords, surcoords, brncoords])


def normalise_data(data, method="var"):
    """Shift and scale *data* to zero mean and unit standard deviation.

    Only method "var" does anything. 2-D input is normalised column by
    column; 3-D input is treated as one image and normalised with the mean
    and standard deviation of its positive voxels. Any other method or
    dimensionality returns *data* unchanged.
    """
    if method != "var":
        return data
    ndim = len(data.shape)
    if ndim == 2:  # do it for each column separately
        return (data - np.mean(data, axis=0)) / np.std(data, axis=0)
    if ndim == 3:  # a single 3D image, so do it for everything at once
        foreground = data[data > 0]
        return (data - foreground.mean()) / foreground.std()
    return data


def call_knn(k_tree, check_dat):
    """Query *k_tree* via its kneighbors method and return [indices, distances].

    Note the order is swapped relative to what kneighbors returns.
    """
    neighbour_dists, neighbour_ids = k_tree.kneighbors(check_dat)
    return [neighbour_ids, neighbour_dists]


def calc_pvals(labels, distvals, kernel=None, method="knn"):
    """Turn neighbour labels (and distances) into one probability per query point.

    Parameters:
      labels   - array, (num pts) x (num neighbours), class labels of the neighbours
      distvals - array, (num pts) x (num neighbours), distances to the neighbours
                 (only used by method "kernel")
      kernel   - callable applied to each row of distvals to produce neighbour
                 weights (only used by method "kernel")
      method   - "knn": mean label over the neighbours;
                 "kernel": kernel-weighted mean label

    Raises ValueError for any other method.
    """
    # input arrays are (num pts) x (num neighbours)
    numNN = labels.shape[1]
    if method == "knn":
        # plain (not in-place) division so integer-typed labels are accepted too
        return np.sum(labels, axis=1) / float(numNN)
    if method == "kernel":
        # build a real list: np.array(map(...)) would wrap the lazy map object itself
        kdists = np.array([kernel(row) for row in distvals])
        npts = labels.shape[0]
        # flatten trailing singleton axes explicitly; np.squeeze would also
        # drop the points axis when there is a single point or neighbour
        weights = kdists.reshape(npts, -1)
        weighted = np.sum(labels.reshape(npts, -1) * weights, axis=1)
        # the floor guards against division by zero when all weights vanish
        return weighted / np.maximum(np.sum(weights, axis=1), 1e-100)
    raise ValueError("Unknown method for calc_pvals: %s" % method)


def mni_coords(sform, vox_coords):
    """Apply the affine *sform* to *vox_coords* (one x, y, z coordinate per row).

    Returns an array of the same layout holding the transformed coordinates.
    """
    linear_part = sform[:3, :3]
    offset = sform[:3, 3].reshape(3, 1)
    # work on column vectors, then flip back to one coordinate per row
    return (np.dot(linear_part, vox_coords.T) + offset).T


def remove_lesions(label_im, inclusion_list, change_labels=False):
    """Zero out the clusters of *label_im* that are not flagged for inclusion.

    Parameters:
      label_im       - integer label image; modified in place (as
                       apply_size_thresh does)
      inclusion_list - sequence indexed by label value (entry 0 is the
                       background label); a truthy entry keeps that label
      change_labels  - when true, renumber the surviving labels consecutively

    Returns the label image: the same (modified) array, or a new renumbered
    array when change_labels is set.
    """
    # an ndarray (not a Python list) is required to look up one flag per voxel
    keep = np.asarray(inclusion_list, dtype=bool)
    discard_voxel = np.logical_not(keep)[label_im]
    label_im[discard_voxel] = 0
    if change_labels:
        remaining = np.unique(label_im)
        label_im = np.searchsorted(remaining, label_im)
    return label_im


def apply_size_thresh(label_im, lesmask, size_thresh):
    """Drop the clusters whose size is at most *size_thresh* and renumber the rest.

    A cluster's size is the sum of *lesmask* over its voxels. *label_im* is
    zeroed in place for the dropped clusters; the returned array is a new,
    consecutively renumbered label image.
    """
    label_ids = range(label_im.max() + 1)
    cluster_sizes = np.asarray(ndimage.sum(lesmask, label_im, label_ids))
    too_small = cluster_sizes <= size_thresh
    label_im[too_small[label_im]] = 0
    return np.searchsorted(np.unique(label_im), label_im)


def calc_les_features(
    lesmask,
    featuretype="clusters",
    label_im=None,
    imval=None,
    sizethresh=0,
    featureset="all",
    affmat=None,
    imobj=None,
    save_data=False):
    """Label the clusters of a lesion mask, or compute per-cluster features.

    Parameters:
      lesmask     - 3-D lesion mask
      featuretype - "clusters", "geometry" or "imageval" (see below)
      label_im    - integer cluster label image; computed from lesmask when
                    None or when featuretype is "clusters"
      imval       - image whose values are summarised (featuretype "imageval")
      sizethresh  - when > 0, clusters no larger than this are removed while
                    labelling (see apply_size_thresh)
      featureset  - "all", or something containing the names of the wanted
                    features ("size", "coords", "geom", "max", "mean", "min")
      affmat      - matrix passed to generate_mni_coords (featuretype "geometry")
      imobj       - image object passed to generate_mni_coords
      save_data   - when true, the raw label image is saved as 'DEBUG_LABEL'
                    using the header of GlobalOpts.imobj

    Returns:
      "clusters" - the label image
      "geometry" - array with one row per label 1..max and, in this order,
                   the columns: size; mean x, y, z coordinate; ratio of the
                   largest to the middle eigenvalue of the cluster's
                   coordinate covariance matrix
      "imageval" - array with one row per label 1..max and, in this order,
                   the columns: max, mean and min of imval within the cluster
      anything else - an empty list
    """
    returnvals = []
    if GlobalOpts.verbose:
        print("Calculating lesion features of type %s" % featuretype)
    if featuretype == "clusters" or label_im is None:
        # full 3x3x3 structuring element: voxels touching by face, edge or corner are joined
        label_im, nb_labels = ndimage.label(lesmask, structure=np.ones([3, 3, 3]))
        if save_data:
            pnii = fslimage.Image(label_im, header=GlobalOpts.imobj.header)
            pnii.save('DEBUG_LABEL')
        if sizethresh > 0:
            label_im = apply_size_thresh(label_im, lesmask, sizethresh)
        nb_labels = label_im.max()
        if featuretype == "clusters":
            return label_im
    if featuretype == "geometry":
        nb_labels = label_im.max()
        # NOTE(review): sizes, nx, ny and nz are not used further down in this branch
        sizes = np.asarray(ndimage.sum(lesmask, label_im, range(nb_labels + 1)))
        nx, ny, nz = lesmask.shape
        returnvals = np.array([])
        # moment of inertia
        # from here on xx, yy, zz are the transformed coordinates of every voxel
        xx, yy, zz = generate_mni_coords(affmat, imobj)
        if featureset == "all" or "size" in featureset:
            # cluster size = sum of the mask over each label (background label 0 excluded)
            nn = np.asarray(ndimage.sum(lesmask, label_im, range(1, label_im.max() + 1)))
            returnvals = h_stack([returnvals, make2d(nn)])
        if featureset == "all" or "coords" in featureset or "geom" in featureset:
            # per-cluster mean coordinates; also needed by the "geom" feature below
            xxr = np.asarray(ndimage.mean(xx, label_im, range(1, label_im.max() + 1)))
            yyr = np.asarray(ndimage.mean(yy, label_im, range(1, label_im.max() + 1)))
            zzr = np.asarray(ndimage.mean(zz, label_im, range(1, label_im.max() + 1)))
            if featureset == "all" or "coords" in featureset:
                returnvals = h_stack([returnvals, make2d(xxr)])
                returnvals = h_stack([returnvals, make2d(yyr)])
                returnvals = h_stack([returnvals, make2d(zzr)])
        if featureset == "all" or "geom" in featureset:
            # per-cluster means of the coordinate products (second moments)
            Imxx = np.asarray(
                ndimage.mean(xx * xx, label_im, range(1, label_im.max() + 1))
            )
            Imxy = np.asarray(
                ndimage.mean(xx * yy, label_im, range(1, label_im.max() + 1))
            )
            Imxz = np.asarray(
                ndimage.mean(xx * zz, label_im, range(1, label_im.max() + 1))
            )
            Imyy = np.asarray(
                ndimage.mean(yy * yy, label_im, range(1, label_im.max() + 1))
            )
            Imyz = np.asarray(
                ndimage.mean(yy * zz, label_im, range(1, label_im.max() + 1))
            )
            Imzz = np.asarray(
                ndimage.mean(zz * zz, label_im, range(1, label_im.max() + 1))
            )
            Im = np.zeros([3, 3])
            rvals = xxr * 0  # zero array with one entry per cluster
            for n in range(len(Imxx)):
                # covariance of the coordinates within cluster n: E[ab] - E[a]E[b]
                Im[0, 0] = Imxx[n] - xxr[n] * xxr[n]
                Im[1, 1] = Imyy[n] - yyr[n] * yyr[n]
                Im[2, 2] = Imzz[n] - zzr[n] * zzr[n]
                Im[0, 1] = Imxy[n] - xxr[n] * yyr[n]
                Im[0, 2] = Imxz[n] - xxr[n] * zzr[n]
                Im[1, 2] = Imyz[n] - yyr[n] * zzr[n]
                # mirror the upper triangle so the matrix is symmetric
                Im[1, 0] = Im[0, 1]
                Im[2, 0] = Im[0, 2]
                Im[2, 1] = Im[1, 2]
                evals, evecs = np.linalg.eig(Im)
                evals = np.sort(np.real(evals))  # ascending, so evals[2] is the largest
                if evals[1] > 0:
                    MIratio = evals[2] / evals[1]  # FA-like quantity (in-plane)
                else:
                    # fallback that avoids dividing by a non-positive middle eigenvalue
                    MIratio = evals[2] * 100.0
                rvals[n] = MIratio
            returnvals = h_stack([returnvals, make2d(rvals)])
    if featuretype == "imageval":
        if imval is None:
            imval = []
        returnvals = np.array([])
        # NOTE(review): nx, ny and nz are not used further down in this branch
        nx, ny, nz = lesmask.shape
        if featureset == "all" or "max" in featureset:
            rvals = np.asarray(
                ndimage.maximum(imval, label_im, range(1, label_im.max() + 1))
            )
            returnvals = h_stack([returnvals, make2d(rvals)])
        if featureset == "all" or "mean" in featureset:
            rvals = np.asarray(ndimage.mean(imval, label_im, range(1, label_im.max() + 1)))
            returnvals = h_stack([returnvals, make2d(rvals)])
        if featureset == "all" or "min" in featureset:
            rvals = np.asarray(
                ndimage.minimum(imval, label_im, range(1, label_im.max() + 1))
            )
            returnvals = h_stack([returnvals, make2d(rvals)])
    return returnvals


# call calc_les_features using the manual mask as imval, and use the max feature to distinguish the true label for the classification
#
# take a set of subjects
# for each subject use first stage probability outputs (already thresholded) + images to get feature matrix
# calculate a matrix for each subject (each row is a lesion, each column is a feature)
#   also use manual mask as the imval in order to get max() which gives true class
# concatenate matrices from all subjects
# run the big matrix through the classifier to train it
# run the query subject through the feature calculation function
# apply the classifier to the query subject to classify each lesion
# put probabilities back into the lesion voxels


def calc_features(imdata, pts, brainmask):
    """Build the intensity feature columns of one image at the voxels *pts*.

    The first column is the normalised intensity. For every entry of
    GlobalOpts.patchsizes a further column is appended: the local (box
    filtered) average of the normalised image, divided by the smoothed brain
    mask and normalised again. Unless GlobalOpts.patch3D is set, the patch is
    one voxel thick along the axis with the largest voxel dimension.
    """
    imdata = normalise_data(imdata)
    features = getvals(imdata, pts)
    # create local average (patch) data features if requested
    if len(GlobalOpts.patchsizes) == 0:
        return features
    mask_sm = None
    for psize in GlobalOpts.patchsizes:
        window = [psize, psize, psize]
        if not GlobalOpts.patch3D:
            window[np.argmax(GlobalOpts.pixdims)] = 1
        if mask_sm is None:
            # NOTE(review): the smoothed mask is built once, with the first patch
            # size, and reused for every later size - confirm this is intended
            mask_sm = ndimage.uniform_filter(brainmask, size=window)
            # set background to 1 to avoid divide by zeros
            mask_sm = mask_sm + 1 - brainmask
        patch_mean = ndimage.uniform_filter(imdata, size=window) / mask_sm
        patch_mean = normalise_data(patch_mean)
        features = h_stack([features, getvals(patch_mean, pts)])
    return features


def calc_coord_features(affmat, pts, im_struct, use_coords=True):
    """Return the weighted, normalised transformed coordinates of *pts*.

    The coordinates come from mni_coords using the matrix built by
    vox2mni_mat, are normalised with normalise_data and multiplied by
    GlobalOpts.spatial_weighting. An empty list is returned when *use_coords*
    is false.
    """
    if not use_coords:
        return []
    vox2mni = vox2mni_mat(affmat, im_struct)
    # only uses subset for normalisation! (is this a problem?)
    normalised = normalise_data(mni_coords(vox2mni, pts))
    return normalised * GlobalOpts.spatial_weighting


#######################################################################################

# Functions for dealing with files and filenames


def load_les_data(
    filenamearray,
    matfilelist=None,
    labelfilelist=None,
    lesprobfilelist=None,
    spatial_features=False,
    use_labels=True,
    feature2set="all",
    save_data=False):
    """Build the per-lesion feature matrix (and true labels) for a set of subjects.

    Parameters:
      filenamearray   - one entry per subject, each a sequence of file names
                        (one per modality, possibly including the label,
                        matrix and lesion probability files, which are skipped)
      matfilelist     - FLIRT matrix file per subject ([''] for none)
      labelfilelist   - manual lesion mask per subject ([''] for none, as for
                        a query subject)
      lesprobfilelist - first-stage lesion probability image per subject
                        (required)
      spatial_features, use_labels - accepted but not used in this function
      feature2set     - feature selection passed on to calc_les_features
      save_data       - keep the probability image object in GlobalOpts.imobj
                        and save the debug label image

    The file lists are consumed (popped from the front) as subjects are
    processed. The process exits with status 3 when a subject has no lesion
    probability image.

    Returns [alldatavals, labelles, alllabelvals]: the stacked feature rows
    of all subjects, the cluster label image of the last subject, and the
    stacked per-lesion true labels (empty when no manual masks were given).
    """
    if matfilelist is None:
        matfilelist = ['']
    if labelfilelist is None:
        labelfilelist = ['']
    if lesprobfilelist is None:
        lesprobfilelist = ['']
    # for each training subject need to:
    #  - load first stage prob data and threshold
    #  - define lesions and give each a number
    #  - for each lesion calculate features
    #     - from thresholded prob: size, moment of inertia
    #     - from CSF dist: min dist to CSF  (needs set of dist values for all voxels in lesion)
    #     - from coords: symmetry (requires labels from all lesions)
    #     - etc.
    #  - also calculate overlap with manual mask (in 3D - via max within lesion) to determine if it is a true or false lesion
    # for a query subject need to do all of the above except the overlap with the manual mask
    # Output should be a single row of the feature matrix for each lesion
    #
    # Replace choose_pts with something that finds and numbers the lesions
    #  - return an image with labelled voxels rather than a list of points
    # return array has each row corresponding to an individual training sample (pt)

    alldatavals = np.array([])
    alllabelvals = np.array([])
    labelles = np.array([])  # stays empty when filenamearray holds no subjects
    # snapshot taken before any list is consumed below
    nonimagefiles = labelfilelist + matfilelist + lesprobfilelist
    # modality-based lists (each list contains a set of subjects for one modality)
    for filelist in filenamearray:  # each loop is a separate subject (for training)
        # get lesion probability, threshold and define clusters
        if not (lesprobfilelist and lesprobfilelist[0]):
            print("ERROR:: Could not find any lesion probability image")
            sys.exit(3)
        lesprobfilename = lesprobfilelist.pop(0)
        if GlobalOpts.verbose:
            print(" ... reading lesion probability %s" % lesprobfilename)
        imobj = fslimage.Image(lesprobfilename)
        if save_data:
            GlobalOpts.imobj = imobj
        lprobim = imobj.data
        lprobim = (lprobim > GlobalOpts.lesprobthresh).astype(np.float32)
        labelles = calc_les_features(
            lprobim,
            featuretype="clusters",
            sizethresh=GlobalOpts.sizethresh,
            save_data=save_data,
        )
        # read matrix mapping to MNI space
        flirtmat = np.identity(4)
        if matfilelist and matfilelist[0]:  # FLIRT matrix from native to MNI space
            matfile = matfilelist.pop(0)  # pick off first item from list
            if GlobalOpts.verbose:
                print(" ... reading matrix %s" % matfile)
            flirtmat = np.loadtxt(matfile)
        # extract label values from manual lesion mask
        # (reset per subject: a query subject has no manual mask, and stale
        #  values from a previous subject must not be stacked again)
        manlabvals = np.array([])
        if labelfilelist and labelfilelist[0]:
            labelfilename = labelfilelist.pop(0)  # current label file
            if GlobalOpts.verbose:
                print(" ... reading manual label %s" % labelfilename)
            manlabim = (fslimage.Image(labelfilename).data > 0.5).astype(
                np.float32
            )  # force binarisation
            # max of the manual mask within each cluster gives its true class
            manlabvals = calc_les_features(
                lprobim,
                featuretype="imageval",
                featureset="max",
                label_im=labelles,
                imval=manlabim,
            )
        # calculate features from probability image
        combdata = calc_les_features(
            lprobim,
            featuretype="geometry",
            featureset=feature2set,
            label_im=labelles,
            affmat=flirtmat,
            imobj=imobj,
        )
        for filename in filelist:  # each file is a separate modality
            # only build training features from valid data files (not mat or labels)
            if filename in nonimagefiles:
                continue
            if GlobalOpts.verbose:
                print(" ... reading image %s" % filename)
            imdata = fslimage.Image(filename).data.astype(np.float32)
            thisdata = calc_les_features(
                lprobim,
                featuretype="imageval",
                label_im=labelles,
                imval=imdata,
                featureset=feature2set,
            )
            combdata = h_stack([combdata, thisdata])
        # accumulate features and pts across different subjects
        alldatavals = v_stack([alldatavals, combdata])
        alllabelvals = v_stack([alllabelvals, manlabvals])
    return [alldatavals, labelles, alllabelvals]


def load_vox_data(
    filenamearray,
    matfilelist=None,
    labelfilelist=None,
    numpts=1,
    nonlespts=1,
    mode=None,
    spatial_features=False,
    use_labels=True):
    """Build the per-voxel feature matrix, points and labels for a set of subjects.

    For every subject (one entry of *filenamearray*): the label image is read
    if one is listed, the brain mask is derived from the image at index
    GlobalOpts.maskfeaturenum, sample points are chosen with choose_pts,
    intensity features are computed for every modality file, and coordinate
    features are appended when a matrix file is listed. The label and matrix
    lists are consumed (popped from the front) along the way, and
    GlobalOpts.pixdims is set from each subject's mask image.

    Returns [alldatavals, allpts, alllabelvals], each stacked over subjects
    with one row per sampled voxel.
    """
    matfilelist = [''] if matfilelist is None else matfilelist
    labelfilelist = [''] if labelfilelist is None else labelfilelist
    mode = ["equalpoints", "any"] if mode is None else mode

    # return array has each row corresponding to an individual training sample (pt)
    alldatavals = np.array([])
    allpts = np.array([])
    alllabelvals = np.array([])
    # snapshot taken before the lists are consumed below
    nonimagefiles = labelfilelist + matfilelist
    # modality-based lists (each list contains a set of subjects for one modality)
    for filelist in filenamearray:  # each loop is a separate subject (for training)
        labelim = np.array([])
        lesmask = np.array([])
        # extract label values from lesion/label mask
        if labelfilelist[0]:
            labelfilename = labelfilelist.pop(0)  # current label file
            if GlobalOpts.verbose:
                print(" ... reading label %s" % labelfilename)
            try:
                labelim = fslimage.Image(labelfilename).data.astype(np.float32)
                labelim = (labelim > 0.5).astype(np.float32)
                if use_labels:
                    lesmask = labelim
            except Exception:
                # best effort: a missing or unreadable label is treated as "no label"
                if GlobalOpts.verbose:
                    print(" ... no label found - ignoring")

        # calculate common pts across all modalities (load brain mask using maskfeaturenum - it needs to have been remapped if subset of features used)
        imobj = fslimage.Image(filelist[GlobalOpts.maskfeaturenum])
        GlobalOpts.pixdims = imobj.pixdim
        brainmask = (imobj.data > 0).astype(np.float32)
        pts = choose_pts(
            brainmask, numpts=numpts, nonlespts=nonlespts, lesmask=lesmask, mode=mode
        )

        # has to be done here, as need pts and brainmask
        if np.sum(labelim.shape) > 0:
            labelvals = getvals(labelim, pts)
        else:
            labelvals = np.array([])

        combdata = np.array([])
        for filename in filelist:  # each file is a separate modality
            # only build training features from valid data files (not mat or labels)
            if filename in nonimagefiles:
                continue
            if GlobalOpts.verbose:
                print(" ... reading image %s" % filename)
            imdata = fslimage.Image(filename).data.astype(np.float32)
            combdata = h_stack([combdata, calc_features(imdata, pts, brainmask)])

        # read matrix mapping to MNI space
        if matfilelist[0]:  # FLIRT matrix from native to MNI space
            matfile = matfilelist.pop(0)  # pick off first item from list
            if GlobalOpts.verbose:
                print(" ... reading matrix %s" % matfile)
            affmat = np.loadtxt(matfile)
            mnicoords = calc_coord_features(
                affmat, pts, imobj, use_coords=spatial_features
            )
            combdata = h_stack([make2d(combdata), mnicoords])

        # accumulate features and pts across different subjects
        alldatavals = v_stack([alldatavals, combdata])
        allpts = v_stack([allpts, pts])
        alllabelvals = v_stack([alllabelvals, labelvals])
    return [alldatavals, allpts, alllabelvals]


def readnonblanklines(filename):
    """Return the lines of the text file *filename* that are not blank.

    The file is decoded as UTF-8 with any leading byte-order mark dropped.
    Lines consisting only of whitespace are removed; the remaining lines are
    returned as they are, without their line terminator.
    """
    # the context manager closes the file even if reading fails
    with open(filename, encoding="utf-8-sig") as listfile:
        return [line for line in listfile.read().splitlines() if line.strip()]


def makefilearray(datafiles, filelistmode="modality_based"):
    """Read the list file(s) and return the file names grouped by subject.

    Parameters:
      datafiles    - list of list-file names
      filelistmode - "singlefile": one master file, one subject per row with
                       the columns split on GlobalOpts.delimiter (any
                       whitespace when that is unset); rows and columns are
                       swapped when GlobalOpts.transposefile is set
                     "modality_based": one list file per modality, each
                       holding one file name per subject
                     anything else: one list file per subject, each holding
                       one file name per modality

    Returns a list with one entry per subject, each a sequence of file names.
    Exits with status 2 when singlefile mode is given more than one file, and
    with status 1 when the list files differ in length.
    """
    # Supports two ways of reading lists: subject-based and modality-based
    #   subject-based lists (each list contains a set of modalities for one subject)
    #   modality-based lists (each list contains a set of subjects for one modality)
    if filelistmode == "singlefile":
        if len(datafiles) != 1:
            print("\nERROR:: more than one file specified in singlefile mode")
            sys.exit(2)
        # None makes str.split break on any run of whitespace
        delimiter = GlobalOpts.delimiter if GlobalOpts.delimiter else None
        filenamearray = [x.split(delimiter) for x in readnonblanklines(datafiles[0])]
        if GlobalOpts.transposefile:
            # materialise the zip so the result can be indexed and re-iterated
            filenamearray = list(zip(*filenamearray))
        return filenamearray

    # for modality-based, unpack initially so that each top level item is a modality (fixed below)
    filenamearray0 = [readnonblanklines(fn) for fn in datafiles]
    if not all(len(fn) == len(filenamearray0[0]) for fn in filenamearray0):
        print("\nERROR:: unequal number of files specified")
        print("Specified files are: %s" % filenamearray0)
        sys.exit(1)
    if filelistmode == "modality_based":
        # now fix it so that top level items are grouped by subject (not modality) - i.e. transpose
        # (materialised as a list so the result can be indexed and re-iterated)
        return list(zip(*filenamearray0))
    return filenamearray0


#######################################################################################


def Usage(exit_status=None):
    """Print the command-line help text and optionally terminate the program.

    Args:
        exit_status: if not None, ``sys.exit`` is called with this status
            after the help text has been printed; if None the function
            simply returns.
    """
    # Legacy (list-file based) invocations, kept for reference:
    # print("\nUsage: %s -l <labelfilelist> -q <queryfilelist> [-m <matfilelist>] [--querymat=<matfile>] [-v] [--nn=<number>] [--trainingpts=<number>] [--spatialweight=<value>] [--classifymethod=<methodname>] [--featuresubset=<num>,<num>,...] <modality1_filelist> <modality2_filelist> ..." % sys.argv[0])
    # print("\nUsage: %s -l <labelfilelist> -q <queryfilelist> [-m <matfilelist>] [--querymat=<matfile>] [-v] [--nn=<number>] [--trainingpts=<number>] [--spatialweight=<value>] [--classifymethod=<methodname>] [--featuresubset=<num>,<num>,...] --listbysubject <subject1_filelist> <subject2_filelist> ..." % sys.argv[0])
    #  Note that all options must go before the modality filelists
    #  All filelists should contain a set of filenames (for each subject, one per row)
    #  The exception is the queryfilelist which should contain a list of individual modality files needed to construct the features
    print(
        "\nUsage: %s --singlefile=<masterlistfile> --labelfeaturenum=<num> --brainmaskfeaturenum=<num> --querysubjectnum=<num> [options]"
        % sys.argv[0]
    )
    # NB: the example call below must not end --trainingnums with a comma:
    # an empty field in a comma-separated list is not a valid number
    print(tw.dedent("""
    Compulsory arguments:
     * --singlefile=<masterlistfile>    name of the master file
     * --querysubjectnum=<num>          row number of query subject (in masterlistfile)
     * --brainmaskfeaturenum=<num>      column number (in the master file) of images to derive non-zero mask from.
     * Training dataset specification:
       If the training subjects to use are listed in the master file, the following two arguments need to be specified:
           --labelfeaturenum=<num>          column number (in the master file) of the manual masks (or any placehold name for query subjects)
           and
           --trainingnums=<val>             subjects to be used in training. List of row numbers (comma separated, no spaces) or "all" to use all the subjects in the master file.
       Alternatively load from file (previously saved with --saveclassifierdata, see below):
           --loadclassifierdata=<name>      load training data from file

    Optional arguments:
      -o <outname>                     specify (base) output name of files (default: output_bianca)
      --featuresubset=<num>,<num>,...  set of column numbers (comma separated and no spaces) for features/images to use (default: use all available modalities as intensity features). The image used to derive non-zero mask from must be part of the features subset.
      --matfeaturenum=<num>            column number of matrix files (in masterlistfile). Needed to extract spatial features (MNI coordinates)
      --spatialweight=<value>          weighting for spatial coordinates (default = 1, i.e. variance-normalised MNI coordinates). Requires --matfeaturenum to be specified.
      --patchsizes=<num>,<num>,...     list of patch sizes for local averaging
      --patch3D                        use 3D patches (default is 2D)
      --selectpts=<val>                "any" (default) or "surround" or "noborder"
      --trainingpts=<val>              number (max) of (lesion) points to use (per training subject) or "equalpoints" to select all lesion points and equal number of non-lesion points
      --nonlespts=<val>                number (max) of non-lesion points to use. If not specified will be set to the same amount of lesion points.
      --saveclassifierdata=<name>      save training data to file
      --seed=<seed>                    Seed for random number generator
      -v                               use verbose mode

    Notes:
       (i) The masterlistfile should contain a row per subject (training or testing) and on each row a list of all files needed for that subject (image data [and matrix transform])
       (ii) --labelfeaturenum specified the column that contains the manual mask for the training subjects, counting from 1 and using the ordering that is used in the masterlistfile. For query subjects, use any placehold name in the masterlistfile, to keep the same column order of the training subjects.
       (iii) the featuresubset should be a list of numbers (comma separated and no spaces) that specify the modalities to be used, counting from 1 and using the ordering that is used in the masterlistfile.  If this is not specified the default is to use all modalities as features.
       (iv) either the --trainingnums=<val> or the  --loadclassifierdata=<name> *must* be used.

     An example call is:
       ./bianca --singlefile=masterfilelist.txt --labelfeaturenum=3 --brainmaskfeaturenum=1 --querysubjectnum=2 --trainingnums=1,2,3,4,5,6,7,8,9,10 --featuresubset=1,2 --matfeaturenum=4 --trainingpts=2000 --nonlespts=10000 --selectpts=noborder -o sub2_bianca -v

     An example script for creating a master list file is:
       for fn in Pz* ; do
          mods="";
          for gg in $fn/[A-Z]*.nii.gz $fn/*.mat ; do
            mods="$mods $gg" ;
          done ;
          echo $mods ;
       done
    """))

    if exit_status is not None:
        sys.exit(exit_status)


#######################################################################################


def p_main():
    """Command-line entry point: parse options, build the training set and classify the query subject.

    The function works in these stages:
      1. reset every ``GlobalOpts`` setting to its default value;
      2. parse ``sys.argv`` with ``getopt`` and store the options;
      3. check that the compulsory arguments are present and consistent;
      4. read the file list(s) and derive the training / query file names;
      5. obtain the training data (either built from the training subjects
         or loaded from a previously saved file) and optionally save it;
      6. load the query data, run the nearest-neighbour search and convert
         the neighbour labels/distances into a per-point value;
      7. save the resulting image under the output base name.

    Invalid or inconsistent arguments are reported on stdout and terminate
    the program through ``sys.exit`` with a non-zero status.  The ``FSLDIR``
    environment variable must be set (a ``KeyError`` is raised otherwise).
    """
    fsldir = os.environ['FSLDIR']
    # set defaults
    GlobalOpts.verbose = False
    GlobalOpts.debug = False
    GlobalOpts.classify_method = "knn"
    GlobalOpts.spatial_weighting = -1  # -1 = not set by the user (resolved to 0 or 1 below)
    GlobalOpts.numNN = 40
    GlobalOpts.trainingpts = 2000
    GlobalOpts.nonlespts = 0  # <= 0 = not set (falls back to trainingpts below)
    GlobalOpts.trnmode = []
    GlobalOpts.lesprobthresh = 0.99
    GlobalOpts.sizethresh = 1
    GlobalOpts.spatial_features = False
    GlobalOpts.surround = False
    GlobalOpts.load_classifier = ""
    GlobalOpts.save_classifier = ""
    GlobalOpts.feature_subset = []
    GlobalOpts.feature2set = "all"
    GlobalOpts.filelistmode = "modality_based"
    GlobalOpts.singlefile = ""
    # all "...num" options are stored zero-based; -1 means "not specified"
    GlobalOpts.querysubjnum = -1
    GlobalOpts.trainingnums = []
    GlobalOpts.matfeaturenum = -1
    GlobalOpts.labelfeaturenum = -1
    GlobalOpts.lesprobfeaturenum = -1
    GlobalOpts.maskfeaturenum = -1
    GlobalOpts.delimiter = None
    GlobalOpts.transposefile = False
    GlobalOpts.patchsizes = []
    GlobalOpts.patch3D = False
    GlobalOpts.lesionmode = False
    GlobalOpts.seed = 1
    GlobalOpts.pixdims = []
    # voxel-to-world affine of the 1mm MNI152 standard image shipped with FSL
    GlobalOpts.mnisform = fslimage.Image(
        f'{fsldir}/data/standard/MNI152_T1_1mm').getAffine('voxel', 'world')
    matfilelist = [""]
    queryfiles = ""
    outbasename = ""
    matfiles = ""
    labelfiles = ""
    querymat = ""

    ############# PROCESS COMMAND LINE ARGUMENTS #############

    if len(sys.argv) == 1:
        Usage(0)

    try:
        opts, args = getopt.getopt(
            sys.argv[1:],
            "vhm:l:o:q:",
            [
                "labelfilelist=",
                "matfilelist=",
                "queryfilelist=",
                "out=",
                "nn=",
                "trainingpts=",
                "selectpts=",
                "spatialweight=",
                "classifymethod=",
                "help",
                "querymat=",
                "loadclassifierdata=",
                "saveclassifierdata=",
                "featuresubset=",
                "listbysubject",
                "querysubjectnum=",
                "trainingnums=",
                "nonlespts=",
                "singlefile=",
                "matfeaturenum=",
                "lesprobfeaturenum=",
                "labelfeaturenum=",
                "brainmaskfeaturenum=",
                "delimiter=",
                "pthresh=",
                "sizethresh=",
                "transposefile",
                "patchsizes=",
                "patch3D",
                "feature2set=",
                "debug",
                "seed="
            ],
        )
    except getopt.GetoptError:
        print("\nERROR:: unrecognised option[s]\n")
        Usage(2)
    for opt, arg in opts:
        if opt in ("-h", "--help"):
            Usage(0)
        elif opt == "-v":
            GlobalOpts.verbose = True
        elif opt == "--debug":
            GlobalOpts.debug = True
        elif opt in ("-l", "--labelfilelist"):
            labelfiles = arg
        elif opt in ("-m", "--matfilelist"):
            matfiles = arg
            GlobalOpts.spatial_features = True
        elif opt in ("-q", "--queryfilelist"):
            queryfiles = arg
        elif opt == "--querymat":
            # plain equality: `opt in ("--querymat")` would be a substring
            # test against a string, not a membership test in a tuple
            querymat = arg
        elif opt == "--nn":
            GlobalOpts.numNN = int(arg)
        elif opt == "--nonlespts":
            GlobalOpts.nonlespts = int(arg)
        elif opt == "--trainingpts":
            if arg == "equalpoints":
                GlobalOpts.trainingpts = "all"
                GlobalOpts.trnmode += ["equalpoints"]
            else:
                GlobalOpts.trainingpts = int(arg)
                GlobalOpts.trnmode += ["npoints"]
        elif opt == "--selectpts":
            if arg == "surround":
                GlobalOpts.surround = True
                GlobalOpts.trnmode += ["surround"]
            elif arg == "noborder":
                GlobalOpts.surround = False
                GlobalOpts.trnmode += ["noborder"]
            else:
                GlobalOpts.trnmode += ["any"]
        elif opt == "--spatialweight":
            GlobalOpts.spatial_weighting = float(arg)
            GlobalOpts.spatial_features = True
        elif opt == "--classifymethod":
            GlobalOpts.classify_method = arg
        elif opt == "--loadclassifierdata":
            GlobalOpts.load_classifier = arg
        elif opt == "--saveclassifierdata":
            GlobalOpts.save_classifier = arg
        elif opt == "--delimiter":
            GlobalOpts.delimiter = arg
        elif opt == "--listbysubject":
            # each list file holds the modalities of one subject; any mode other
            # than "modality_based"/"singlefile" is read as-is by makefilearray
            # (this used to set the default "modality_based", i.e. did nothing)
            GlobalOpts.filelistmode = "subject_based"
        elif opt == "--singlefile":
            GlobalOpts.filelistmode = "singlefile"
            GlobalOpts.singlefile = arg
        elif opt == "--transposefile":
            GlobalOpts.transposefile = True
        elif opt == "--patch3D":
            GlobalOpts.patch3D = True
        elif opt == "--featuresubset":
            # users count from 1; empty fields (e.g. a trailing comma) are ignored
            GlobalOpts.feature_subset = [
                int(x) - 1 for x in arg.split(",") if x.strip()
            ]
        elif opt == "--feature2set":
            GlobalOpts.feature2set = arg.split(",")
        elif opt == "--querysubjectnum":
            GlobalOpts.querysubjnum = int(arg) - 1  # let users count from 1, not 0
        elif opt == "--matfeaturenum":
            GlobalOpts.matfeaturenum = int(arg) - 1  # let users count from 1, not 0
            GlobalOpts.spatial_features = True
        elif opt == "--lesprobfeaturenum":
            GlobalOpts.lesprobfeaturenum = int(arg) - 1  # let users count from 1, not 0
            GlobalOpts.lesionmode = True
        elif opt == "--labelfeaturenum":
            GlobalOpts.labelfeaturenum = int(arg) - 1  # let users count from 1, not 0
        elif opt == "--brainmaskfeaturenum":
            GlobalOpts.maskfeaturenum = int(arg) - 1  # let users count from 1, not 0
        elif opt == "--trainingnums":
            if arg == "all":
                # expanded to every row of the master file once it has been read
                GlobalOpts.trainingnums = "all"
            else:
                # users count from 1; empty fields (e.g. a trailing comma) are ignored
                GlobalOpts.trainingnums = [
                    int(x) - 1 for x in arg.split(",") if x.strip()
                ]
        elif opt == "--pthresh":
            GlobalOpts.lesprobthresh = float(arg)
        elif opt == "--sizethresh":
            GlobalOpts.sizethresh = int(arg)
        elif opt == "--patchsizes":
            GlobalOpts.patchsizes = [float(x) for x in arg.split(",") if x.strip()]
        elif opt in ("-o", "--out"):
            outbasename = arg
        elif opt == '--seed':
            GlobalOpts.seed = int(arg)

    # Some sanity checking on compulsory arguments
    # Master file
    if not GlobalOpts.singlefile:
        print("\nERROR: no master file specified ")
        sys.exit(1)
    # Query subject
    if GlobalOpts.querysubjnum == -1 and not queryfiles:
        print("\nERROR: no query subject specified")
        sys.exit(1)
    if GlobalOpts.querysubjnum >= 0 and queryfiles:
        print(
            "\nERROR: cannot specify both a query subject number *and* a query file list"
        )
        sys.exit(1)
    # Image for brain mask. Needed both for training (if not loading file) to select points and for query subject to limit search for lesions
    if GlobalOpts.maskfeaturenum == -1:
        print(
            "\nERROR: mask image not specified (compulsory argument --brainmaskfeaturenum missing) "
        )
        sys.exit(1)

    # Only the two methods handled further down are valid; reject anything
    # else now rather than failing after all the data has been loaded
    if GlobalOpts.classify_method not in ("knn", "kernel"):
        print(
            "\nERROR: unknown classification method '%s' (must be 'knn' or 'kernel') "
            % GlobalOpts.classify_method
        )
        sys.exit(1)

    # Training dataset must be specified either by loading a file or specifying both training subject rows and label column
    if GlobalOpts.load_classifier:
        if GlobalOpts.labelfeaturenum >= 0 or len(GlobalOpts.trainingnums) > 0:
            print(
                "\nERROR: cannot specify both training subjects/labels *and* training file "
            )
            sys.exit(1)
    else:
        if GlobalOpts.labelfeaturenum == -1 and len(GlobalOpts.trainingnums) == 0:
            print("\nERROR: no training dataset specified ")
            sys.exit(1)
        if GlobalOpts.labelfeaturenum == -1 and len(GlobalOpts.trainingnums) > 0:
            print(
                "\nERROR: no training dataset specified (missing reference to manual masks: --labelfeaturenum ) "
            )
            sys.exit(1)
        if GlobalOpts.labelfeaturenum >= 0 and len(GlobalOpts.trainingnums) == 0:
            print(
                "\nERROR: no training dataset specified (missing reference to training subjects: --trainingnums ) "
            )
            sys.exit(1)

    # If 3D patch specified without patchsize gives error
    if len(GlobalOpts.patchsizes) == 0 and GlobalOpts.patch3D:
        print("\nERROR: patchsize(s) must be specified to enable patch 3D option ")
        sys.exit(1)

    # fill in some defaults
    if GlobalOpts.nonlespts <= 0:
        GlobalOpts.nonlespts = GlobalOpts.trainingpts

    # read in files that contain lists of files that point to pre-processed data
    datafiles = list(args)
    if GlobalOpts.filelistmode == "singlefile":
        datafiles = [GlobalOpts.singlefile]
    if labelfiles:
        labelfilelist = readnonblanklines(labelfiles)
    if queryfiles:
        queryfilelist = readnonblanklines(queryfiles)
    if matfiles and GlobalOpts.spatial_features:
        matfilelist = readnonblanklines(matfiles)
        # if matfile provided and no spatial weighting specified, set sw to 1 by default
        if GlobalOpts.spatial_weighting == -1:
            GlobalOpts.spatial_weighting = 1

    verbose = GlobalOpts.verbose
    outbasename = fslimage.removeExt(outbasename)
    if not outbasename:
        if queryfiles:
            outbasename = fslimage.removeExt(queryfilelist[0]) + "_bianca"
        else:
            outbasename = "output_bianca"
    # file array is stored so that top level items are grouped by subject
    filenamearray = makefilearray(datafiles, filelistmode=GlobalOpts.filelistmode)
    nmodes = len(filenamearray[0])
    nsubjs = len(filenamearray)
    if GlobalOpts.trainingnums == "all":
        GlobalOpts.trainingnums = list(range(0, nsubjs))
    # Some sanity checking
    if GlobalOpts.filelistmode != "singlefile":
        if len(labelfilelist) != nsubjs or (
            GlobalOpts.spatial_features and len(matfilelist) != nsubjs
        ):
            print(
                "\nERROR:: inconsistent lengths of label file list, mat file list and data file list"
            )
            sys.exit(1)

    if verbose:
        print(
            "Number of modalities = %d , number of possible training subjects = %d"
            % (nmodes, nsubjs)
        )
        print(
            "Files are: label file list = %s, data file list = %s"
            % (labelfiles, datafiles)
        )
        print(
            "Number of training points = %s, mode = %s"
            % (GlobalOpts.trainingpts, GlobalOpts.trnmode)
        )

    # process options for input files (multiple ways of specifying things need to be handled)
    if GlobalOpts.querysubjnum >= 0:
        # the query subject must not be part of its own training set
        if GlobalOpts.querysubjnum in GlobalOpts.trainingnums:
            GlobalOpts.trainingnums.remove(GlobalOpts.querysubjnum)
    if GlobalOpts.matfeaturenum >= 0:
        if matfiles:
            print(
                "\nERROR:: cannot specify both a matrix feature number *and* a matrix file list"
            )
            sys.exit(1)
        else:
            matfilelist = [x[GlobalOpts.matfeaturenum] for x in filenamearray]
            # if matfile provided and no spatial weighting specified, set to 1 by default
            if GlobalOpts.spatial_weighting == -1:
                GlobalOpts.spatial_weighting = 1
    if GlobalOpts.labelfeaturenum >= 0:
        labelfilelist = [x[GlobalOpts.labelfeaturenum] for x in filenamearray]
        # remove subjects (rows) that aren't needed
        traininglabellist = sublist(labelfilelist, GlobalOpts.trainingnums)
    if GlobalOpts.lesprobfeaturenum >= 0:
        lesprobfilelist = [x[GlobalOpts.lesprobfeaturenum] for x in filenamearray]
    # remove subjects (rows) that aren't needed
    trainingfilearray = sublist(filenamearray, GlobalOpts.trainingnums)
    trainingmatlist = sublist(matfilelist, GlobalOpts.trainingnums)

    if GlobalOpts.lesionmode:
        traininglesproblist = sublist(lesprobfilelist, GlobalOpts.trainingnums)
    if GlobalOpts.feature_subset:
        # remove features (cols) that aren't needed
        trainingfilearray = [
            sublist(x, GlobalOpts.feature_subset) for x in trainingfilearray
        ]
        if verbose:
            print("Filenames = %s" % trainingfilearray)

    # Some sanity checking
    # If no mat files specified set spatialweighting to 0, but if it was set as greater than 0, it gives error
    if not matfiles and GlobalOpts.matfeaturenum == -1:
        if GlobalOpts.spatial_weighting == -1:
            GlobalOpts.spatial_weighting = 0
        if verbose:
            print("no spatial features used")
        if GlobalOpts.spatial_weighting > 0:
            print(
                "\nERROR: matrix files must be specified to use spatial weighting option "
            )
            sys.exit(1)

    # query inputs
    if GlobalOpts.querysubjnum >= 0:
        queryfilelist = filenamearray[GlobalOpts.querysubjnum]
        if GlobalOpts.matfeaturenum >= 0:
            querymat = filenamearray[GlobalOpts.querysubjnum][GlobalOpts.matfeaturenum]
        else:
            # NOTE(review): sublist is given a single index here but lists of
            # indices everywhere else - confirm it accepts both
            querymat = sublist(matfilelist, GlobalOpts.querysubjnum)[0]
    querylabel = [""]
    if GlobalOpts.labelfeaturenum >= 0:
        querylabel = [x[GlobalOpts.labelfeaturenum] for x in [queryfilelist]]
    if GlobalOpts.lesprobfeaturenum >= 0:
        querylesprob = [x[GlobalOpts.lesprobfeaturenum] for x in [queryfilelist]]
    if GlobalOpts.feature_subset:
        # remove features (cols) that aren't needed
        queryfilelist = sublist(queryfilelist, GlobalOpts.feature_subset)

    # convert feature_subset numbers into numbers referring to columns in the data arrays
    #    as labels and mats do not generate columns in the data arrays
    new_feature_subset = []
    if GlobalOpts.feature_subset:
        new_feature_flags = [x in GlobalOpts.feature_subset for x in range(nmodes)]
        if GlobalOpts.matfeaturenum >= 0:
            new_feature_flags[GlobalOpts.matfeaturenum] = False
        if GlobalOpts.labelfeaturenum >= 0:
            new_feature_flags[GlobalOpts.labelfeaturenum] = False
        if GlobalOpts.lesprobfeaturenum >= 0:
            new_feature_flags[GlobalOpts.lesprobfeaturenum] = False
        new_feature_subset = np.where(new_feature_flags)[0]
        # remap maskfeaturenum to new feature numbers
        if GlobalOpts.maskfeaturenum in new_feature_subset:
            GlobalOpts.maskfeaturenum = np.where(
                new_feature_subset == GlobalOpts.maskfeaturenum
            )[0][0]
        else:
            print("\nERROR:: specified mask images must be part of the feature list")
            sys.exit(3)
    else:
        if verbose:
            print(
                "Using all available images as intensity features (no intensity features subset specified)"
            )

    np.random.seed(GlobalOpts.seed)

    ############# NOW DO THE REAL WORK #############

    # obtain the training data: either load a previously saved set or build it
    # from the training subjects (and save it if requested)
    if GlobalOpts.load_classifier:
        print("Loading training data  %s" % (GlobalOpts.load_classifier))
        # load the training data for the classifier
        # SECURITY: unpickling executes arbitrary code from the file - only
        # load classifier data files that come from a trusted source
        with open(GlobalOpts.load_classifier, "rb") as datafile:
            training_dat = pickle.load(datafile)
        with open(GlobalOpts.load_classifier + "_labels", "rb") as labelfile:
            labelvals = pickle.load(labelfile)
        if verbose:
            print("Training data is size %s" % (training_dat.shape,))
    else:
        if verbose:
            print("Loading training data")
        if GlobalOpts.lesionmode:
            [training_dat, _, labelvals] = load_les_data(
                trainingfilearray,
                matfilelist=trainingmatlist,
                labelfilelist=traininglabellist,
                lesprobfilelist=traininglesproblist,
                spatial_features=GlobalOpts.spatial_features,
                feature2set=GlobalOpts.feature2set,
            )
        else:
            [training_dat, _, labelvals] = load_vox_data(
                trainingfilearray,
                matfilelist=trainingmatlist,
                labelfilelist=traininglabellist,
                numpts=GlobalOpts.trainingpts,
                nonlespts=GlobalOpts.nonlespts,
                mode=GlobalOpts.trnmode,
                spatial_features=GlobalOpts.spatial_features,
            )
        if verbose:
            print("Training data is size %s" % (training_dat.shape,))
        if GlobalOpts.save_classifier:
            # save training data for classifier (if requested)
            with open(GlobalOpts.save_classifier, "wb") as datafile:
                pickle.dump(training_dat, datafile)
            with open(GlobalOpts.save_classifier + "_labels", "wb") as labelfile:
                pickle.dump(labelvals, labelfile)
            if verbose:
                print("Saving training data of size %s" % (training_dat.shape,))
                print("Saving training data as  %s" % (GlobalOpts.save_classifier))

    # load query data
    if verbose:
        print("Loading query data")
        print("  Query file list = %s" % (queryfilelist,))
    if GlobalOpts.lesionmode:
        [check_dat, cluster_im, _] = load_les_data(
            [queryfilelist],
            matfilelist=[querymat],
            labelfilelist=querylabel,
            lesprobfilelist=querylesprob,
            spatial_features=GlobalOpts.spatial_features,
            feature2set=GlobalOpts.feature2set,
            save_data=GlobalOpts.debug,
        )
    else:
        [check_dat, querypts, _] = load_vox_data(
            [queryfilelist],
            matfilelist=[querymat],
            labelfilelist=querylabel,
            mode=["allpoints"],
            spatial_features=GlobalOpts.spatial_features,
            use_labels=False,
        )

    if verbose:
        print("Training classifier")
    k_tree = NearestNeighbors(n_neighbors=GlobalOpts.numNN, algorithm="kd_tree").fit(
        training_dat
    )

    # look up values (each pt is a row, neighbours in different columns)
    if verbose:
        print("Applying classifier")
    [idvals, distvals] = call_knn(k_tree, check_dat)
    labels = labelvals[idvals]

    # apply kernel density estimation (generalisation of kNN)
    if verbose:
        print("Calculating p-values")
    if GlobalOpts.classify_method == "knn":
        pvals = calc_pvals(labels, distvals)
    else:
        # "kernel" (the method name was validated right after option parsing)

        def kernelfunc(dist):
            # Gaussian-shaped weighting of the neighbour distance
            sigma = 100
            return np.exp(-dist * dist / (sigma * sigma))

        pvals = calc_pvals(labels, distvals, method="kernel", kernel=kernelfunc)

    # map pvalues back into the volume
    if verbose:
        print("Saving output image")
    # the first query image provides the header (and reference data) for the output
    queryimcl = fslimage.Image(queryfilelist[0])
    if GlobalOpts.lesionmode:
        pvalim = put_les_vals(cluster_im, pvals)
    else:
        pvalim = put_vox_vals(querypts, pvals, refim=queryimcl.data)
    pnii = fslimage.Image(pvalim, header=queryimcl.header)
    pnii.save(outbasename)


#######################################################################################

# Call the main function, but only when this file is executed as a script
# (nothing is run when it is imported as a module)
if __name__ == '__main__':
    p_main()
