#!/usr/bin/env python

import subprocess   as sp
import os.path      as op
import itertools    as it
import                 shlex
import                 sys
import                 traceback

import numpy        as np
import scipy.signal as spsig
import nibabel      as nib

from fsl.utils.tempdir import tempdir
from fsl.data.image    import Image, addExt


def sprun(cmd):
    """Echo ``cmd`` to stdout, then run it, failing loudly on error.

    :arg cmd: Command line to run. It is tokenised with ``shlex.split`` and
              executed directly (not through a shell), so shell syntax such
              as pipes, globs or redirections is not interpreted.
    :raises subprocess.CalledProcessError: if the command exits with a
              non-zero status (``check=True``).
    :raises OSError: if the program cannot be executed (e.g. not found).
    """
    # Flush explicitly: when stdout is a file or pipe Python block-buffers
    # it, and without the flush the "RUN" line can show up *after* the
    # child process's own output in logs.
    print(f'RUN {cmd}', flush=True)
    sp.run(shlex.split(cmd), check=True)


# Data for testing various aspects of NEWIMAGE::complexvolume.
# Two images are written: a "neuro" image, and a "radio" image whose
# data and affine are both flipped along the X axis.
def create_test_complexvolume_data(seed=1):
    """Write a neuro/radio pair of complex-valued NIfTI images.

    Two 10x10x10 images are saved to the current directory:

    - ``test_complexvolume_neuro.nii.gz``: real and imaginary parts are
      random integers in [0, 100), stored as float32 before combining,
      with a positive X voxel size.
    - ``test_complexvolume_radio.nii.gz``: the same data flipped along the
      first voxel axis, with a negated X scale and an X origin of 37
      (= 10 + 3 * 9), so voxel ``i`` of the radio image maps to the same
      world coordinate as voxel ``9 - i`` of the neuro image.

    Each image's sform (and its code) is copied into its qform.

    :arg seed: Value passed to ``np.random.seed`` before the data is drawn.
    """
    np.random.seed(seed)

    shape = (10, 10, 10)
    real  = np.random.randint(0, 100, shape).astype(np.float32)
    imag  = np.random.randint(0, 100, shape).astype(np.float32)

    # test_complexvolume.cc hard-codes these same affines - if you change
    # them here, you must change them there too (and vice versa).
    affines = {
        'neuro' : np.diag([ 3, 3, 3, 1]),
        'radio' : np.diag([-3, 3, 3, 1]),
    }
    affines['neuro'][:3, 3] = [10, 20, 30]
    affines['radio'][:3, 3] = [37, 20, 30]

    volumes = {
        'neuro' : real + imag * 1j,
        'radio' : np.flip(real, 0) + np.flip(imag, 0) * 1j,
    }

    for orient in ('neuro', 'radio'):
        img = nib.Nifti1Image(volumes[orient], affines[orient])
        img.set_qform(*img.get_sform(coded=True))
        img.to_filename(f'test_complexvolume_{orient}.nii.gz')


# Data for testing NEWIMAGE::convolve and its masked overload.
# Expected outputs are computed with scipy.signal.convolve.
def create_test_convolve_data():
    """Write the input and expected-output images for the convolve tests.

    All images are saved to the current directory:

    - ``test_convolve_kernel.nii.gz``: a 5x5x5 kernel of ones.
    - ``test_convolve_source.nii.gz``: a 100x100x100 volume whose value at
      voxel (x, y, z) is ``sin((x + y + z) / pi)``.
    - ``test_convolve_data.nii.gz``: ``source`` thresholded at zero
      (1 where positive, 0 elsewhere), as int32.
    - ``test_convolve_mask.nii.gz``: 1 where
      ``floor(((x + y + z) / pi) / pi) % 4 == 0``, 0 elsewhere, as int32.
    - ``test_convolve_benchmark.nii.gz``: ``data`` convolved with the
      kernel by ``scipy.signal.convolve`` (``mode='same'``, direct method).
    - ``test_convolve_masked_benchmark.nii.gz``: the benchmark inside the
      mask, and the unmodified ``data`` values outside it.
    """
    kernel = np.ones((5, 5, 5))

    # Open grids of shape (100,1,1), (1,100,1) and (1,1,100) broadcast to
    # the full (100, 100, 100) volume when summed. This computes every
    # voxel's value in a few vectorised operations instead of a
    # million-iteration Python loop, using the same int -> float64 division
    # and the same floor-division semantics as the scalar expressions.
    x, y, z = np.ogrid[:100, :100, :100]
    idx     = (x + y + z) / np.pi
    source  = np.sin(idx)

    # select every second maxima: sin(idx) is positive when idx // pi is
    # even, and keeping only multiples of 4 keeps every other positive lobe.
    mask = (idx // np.pi) % 4 == 0

    data      = (source > 0).astype(np.int32)
    benchmark = spsig.convolve(data, kernel, mode='same', method='direct')

    # Outside the mask, the expected masked result is the input itself.
    masked_benchmark        = np.array(benchmark)
    masked_benchmark[~mask] = data[~mask]
    mask                    = mask.astype(np.int32)

    Image(kernel)          .save('test_convolve_kernel.nii.gz')
    Image(data)            .save('test_convolve_data.nii.gz')
    Image(source)          .save('test_convolve_source.nii.gz')
    Image(mask)            .save('test_convolve_mask.nii.gz')
    Image(benchmark)       .save('test_convolve_benchmark.nii.gz')
    Image(masked_benchmark).save('test_convolve_masked_benchmark.nii.gz')


def create_test_fileformats_data():
    """Save an all-zero 20x20x20 image once per (NIfTI version, suffix) pair.

    Versions 1 and 2 are each combined with the ``img`` and ``nii``
    suffixes, both uncompressed and with ``.gz``, ``.bz2`` and ``.zst``
    extensions. Each file is named
    ``test_fileformats_<version>_<suffix with dots as underscores>.<suffix>``,
    e.g. ``test_fileformats_2_nii_gz.nii.gz``.
    """
    zeros = np.zeros((20, 20, 20))

    for version in (1, 2):
        for base in ('img', 'nii'):
            for compression in ('', '.gz', '.bz2', '.zst'):
                suffix   = base + compression
                name_tag = f'{version}_{suffix}'.replace('.', '_')
                outfile  = f'test_fileformats_{name_tag}.{suffix}'
                Image(zeros, version=version).save(outfile)


if __name__ == '__main__':
    # Rebuild the test program from scratch, generate the test data
    # files, then run the test-newimage binary.
    for build_step in ('make clean', 'make'):
        sprun(build_step)

    generators = (
        create_test_complexvolume_data,
        create_test_convolve_data,
        create_test_fileformats_data,
    )
    for generate in generators:
        generate()

    sprun('./test-newimage -l test_suite')
