"""
This is a demonstration file that explains how to
use an image filter.

This filter classifies each slice of a dataset with the Gaussian mixture model
implementation from the scikit-learn package.

#. Import a dataset;
#. Go to *Tools/Image Processing Toolbox* to start the Image Processing;
#. In the combo box of operations, go to the section *Demos* to select *DemoImageFilterGMM*;
#. In the filter options, in the field *Class count*,
   enter the number of classes into which the voxels are classified (usually 2, 3 or 4);
#. Press the button *Compute Selected Preview* to execute the filter on the dataset.

:author: ORS Team
:contact: http://theobjects.com
:email: info@theobjects.com
:organization: Object Research Systems (ORS), Inc.
:address: 760 St-Paul West, suite 101, Montréal, Québec, Canada, H3C 1M4
:copyright: Object Research Systems (ORS), Inc. All rights reserved 2020.
:date: Oct 03 2017 08:55
:dragonflyVersion: 3.1.0.307 (D)
:UUID: 2c35d4f8a83a11e781be448a5b5d70c0
"""

# Version string of this demonstration filter module.
__version__ = '1.0.0'

import numpy as np
from sklearn.mixture import GaussianMixture

from ORSModel import orsObj, createNumpyArrayFromChannel
from COMWrapper.ORS_def import CxvChannel_Data_Type
from OrsPythonPlugins.OrsSimpleFilters.filters.filterabstract import FilterAbstract
from OrsPythonPlugins.OrsSimpleFilters.filters.filterutil import FilterUtil


class DemoImageFilterGMM_2c35d4f8a83a11e781be448a5b5d70c0(FilterAbstract):
    """
    Demonstration filter classifying each Z slice of a dataset with a Gaussian mixture model.

    For every slice, a scikit-learn GaussianMixture with *classCount* components is fitted on the
    voxel intensities. Each voxel is then labelled with the index (0 to classCount - 1) of the
    intensity interval containing it, the intervals being delimited by the midpoints between
    adjacent sorted gaussian means.
    """

    @classmethod
    def getFilterName(cls):
        """
        Method called to get the full name of the filter.
        This name should be unique among all filters.
        :return: string
            Full name of the filter
        """
        return 'DemoImageFilterGMM'

    @classmethod
    def getAbbreviatedOutputName(cls, outputIndex):
        """
        Method called to get an abbreviated name for an output.
        Output 0 is named 'DemoFiltGMM'; any other output gets its index appended.
        :return: string
            Abbreviated name of the filter
        """
        if outputIndex == 0:
            return 'DemoFiltGMM'
        else:
            return 'DemoFiltGMM/' + str(outputIndex)

    @classmethod
    def getFilterCategory(cls):
        """
        Method called to get a category for the filter.
        :return: string
            Category of the filter (ex: Smoothing, Edge detection, ...)
        """
        return 'Demo'

    @classmethod
    def getFilterInfo(cls):
        """
        Method called to get an information string about the filter.
        This string is displayed in the UI.
        This string could serve as a general description or to display a warning.
        :return: str
        """
        return 'This is a demonstration filter.'

    @classmethod
    def getClassCount(cls, booleanArguments, numericArguments, stringArguments):
        """
        Method called to know if this filter produces an output with a defined count of classes.
        :return: int
            Count of classes
        """
        # Numeric arguments may be stored as floats (e.g. 2.0). The classification itself casts
        # classCount with int() (see applyOnPatch), so the reported count is cast the same way.
        return int(numericArguments['classCount'])

    @classmethod
    def getInputCount(cls):
        """
        Method called to know how many inputs (datasets) are required.
        :return: int
        """
        return 1

    @classmethod
    def getOutputCount(cls):
        """
        Method called to know how many outputs (datasets) are produced.
        :return: int
        """
        return 1

    @classmethod
    def _apply(cls, xMin, yMin, zMin, tMin, xMax, yMax, zMax, tMax,
               listInputChannelId, listOutputChannelId, listIndexFirstVoxelInput, listIndexFirstVoxelOutput,
               progressId, numpyKernel, dictBooleanArguments, dictNumericArguments, dictStringArguments, zOutputSize=1):
        """
        This method is overloaded because the computation has to be done on full slices.
        It is essentially copied from the parent class FilterAbstract.

        For each time step in [tMin, tMax], the classification is computed on complete XY slices of
        the input (only the Z range is restricted to [zMin, zMax]). The voxels of the requested box
        are then written to the output array. When the requested box covers the whole input and
        output, applyOnPatch writes the output directly.
        :param zOutputSize: int
            Added to the upper Z bound when slicing the output; 1 makes the Z range inclusive.
        """

        dataInputGUID = listInputChannelId[0]
        dataOutputGUID = listOutputChannelId[0]
        dataInput = createNumpyArrayFromChannel(dataInputGUID)
        dataOutput = createNumpyArrayFromChannel(dataOutputGUID)

        # The arrays are indexed as [t][z, y, x] below, so their last three axes are Z, Y, X.
        # Sizes are read from .shape: NumPy arrays do not provide getXSize()/getYSize()/getZSize().
        xSizeInput, ySizeInput, zSizeInput = dataInput.shape[-1], dataInput.shape[-2], dataInput.shape[-3]
        xSizeOutput, ySizeOutput, zSizeOutput = dataOutput.shape[-1], dataOutput.shape[-2], dataOutput.shape[-3]

        xInput0, yInput0, zInput0, tInput0 = listIndexFirstVoxelInput[0]
        xOutput0, yOutput0, zOutput0, tOutput0 = listIndexFirstVoxelOutput[0]

        xSize = xMax - xMin + 1
        ySize = yMax - yMin + 1
        zSize = zMax - zMin + 1

        requestedSize = (xSize, ySize, zSize)
        isOnFullSpatialChannel = (requestedSize == (xSizeInput, ySizeInput, zSizeInput) and
                                  requestedSize == (xSizeOutput, ySizeOutput, zSizeOutput))

        xMinComputed = xInput0  # All data in X is required
        yMinComputed = yInput0  # All data in Y is required
        zMinComputed = max(zMin, zInput0)  # Only the current Z is required

        xMaxComputed = xInput0 + xSizeInput - 1  # All data in X is required
        yMaxComputed = yInput0 + ySizeInput - 1  # All data in Y is required
        zMaxComputed = min(zMax, zInput0 + zSizeInput - 1)  # Only the current Z is required

        IProgress = orsObj(progressId)

        for t in range(tMin, tMax + 1):
            if IProgress is not None and IProgress.getIsCancelled():
                return

            tInput = t - tInput0
            tOutput = t - tOutput0

            dataInputTSlice = dataInput[tInput]
            dataOutputTSlice = dataOutput[tOutput]

            if isOnFullSpatialChannel:
                # applyOnPatch writes the result directly into dataOutputTSlice in this case.
                cls.applyOnPatch(dataInputTSlice, dataOutputTSlice, isOnFullSpatialChannel, dictBooleanArguments,
                                 dictNumericArguments, dictStringArguments, progressId)
            else:
                dataInputTSliceComputationArea = dataInputTSlice[zMinComputed - zInput0:zMaxComputed - zInput0 + 1,
                                                                 yMinComputed - yInput0:yMaxComputed - yInput0 + 1,
                                                                 xMinComputed - xInput0:xMaxComputed - xInput0 + 1]
                resultComputation = cls.applyOnPatch(dataInputTSliceComputationArea, dataOutputTSlice,
                                                     isOnFullSpatialChannel, dictBooleanArguments, dictNumericArguments,
                                                     dictStringArguments, progressId)

                # None means the computation was cancelled: nothing is written for this time step.
                if resultComputation is not None:
                    # Copy only the requested box, converting from the computation-area coordinates
                    # (origin at *MinComputed) to the output coordinates (origin at *Output0).
                    dataOutputTSlice[zMin - zOutput0:zMax - zOutput0 + zOutputSize,
                                     yMin - yOutput0:yMax - yOutput0 + 1,
                                     xMin - xOutput0:xMax - xOutput0 + 1] = resultComputation[
                                                           zMin - zMinComputed:zMax - zMinComputed + zOutputSize,
                                                           yMin - yMinComputed:yMax - yMinComputed + 1,
                                                           xMin - xMinComputed:xMax - xMinComputed + 1]

    @classmethod
    def applyOnPatch(cls, dataInputTSlice, dataOutputTSlice, isOnFullSpatialChannel, dictBooleanArguments,
                     dictNumericArguments, dictStringArguments, progressId):
        """
        Classify a (Z, Y, X) block of one time step through the parent's parallel helper.
        :return:
            None if the computation was cancelled or if the result was written into dataOutputTSlice
            (full-channel case); otherwise the classified block.
        """
        IProgress = orsObj(progressId)

        extra_keywords = {'progressId': progressId,
                          'classCount': int(dictNumericArguments['classCount'])}

        result = cls._apply_parallel(cls.patchJob, dataInputTSlice, isOnFullSpatialChannel=isOnFullSpatialChannel,
                                     extra_keywords=extra_keywords)
        if IProgress is not None and IProgress.getIsCancelled():
            return None
        if isOnFullSpatialChannel:
            dataOutputTSlice[:] = result[:]
            return None
        else:
            return result

    @classmethod
    def patchJob(cls, dataInputTSlice, classCount, progressId):
        """
        Classify every Z slice of a (Z, Y, X) block independently.
        :return:
            Array with the shape and dtype of dataInputTSlice holding the class indices; it is filled
            with zeros if the computation is cancelled.
        """
        zI = dataInputTSlice.shape[0]
        result = dataInputTSlice.copy()
        IProgress = orsObj(progressId)
        for zindex in range(zI):
            if IProgress is not None and IProgress.getIsCancelled():
                result[:] = 0  # Clearing all result
                return result
            result[zindex, ...] = cls._classifyOnSlice(dataInputTSlice[zindex, :, :], classCount)
        return result

    @classmethod
    def _classifyOnSlice(cls, dataInputZSlice, classCount, randomState=0):
        """
        Classify the voxels of one 2D slice into classCount intensity classes.
        :param dataInputZSlice: 2D array of voxel intensities
        :param classCount: int
            Number of gaussian components, and therefore of output classes
        :param randomState: int
            Seed given to GaussianMixture. It is fixed by default so that the same slice always gets
            the same classification (e.g. preview and final computation agree).
        :return: uint8 array with the shape of dataInputZSlice, values in [0, classCount - 1]
        """
        classificationGMM = GaussianMixture(n_components=classCount, random_state=randomState)
        classificationGMM.fit(dataInputZSlice.reshape((dataInputZSlice.size, 1)))

        # Computing the limits of each class by taking the middle of each adjacent gaussian mean.
        # A cleaner separation might be obtained if the standard deviation of each gaussian is used,
        # to identify the local minima of these adjacent gaussians.
        gaussianMeans = np.sort(classificationGMM.means_.flatten())
        gaussianSeparations = (gaussianMeans[:-1] + gaussianMeans[1:]) / 2

        # With side='right', searchsorted counts the separations <= each voxel value. That count is
        # the index i of the interval [separation[i-1], separation[i]) containing the voxel, with
        # -inf and +inf as implicit outer bounds. This is one vectorized pass instead of one boolean
        # mask per class.
        classIndices = np.searchsorted(gaussianSeparations, dataInputZSlice, side='right')
        return classIndices.astype(np.uint8)

    @classmethod
    def getSuggestedOutputDataType(cls, listInputChannelId, outputChannelIndex):
        """
        Method called to get the suggested data type of an output.
        :return: CxvChannel_Data_Type
            Always unsigned byte, independently of the input data type.
        """
        # The GMM filter produces class indices from 0 to classCount - 1. classCount is limited to
        # 255, so the largest index is 254, which always fits in an unsigned char.
        return CxvChannel_Data_Type.CXVCHANNEL_DATA_TYPE_UNSIGNED_BYTE

    @classmethod
    def getLengthDependenceX(cls, inputChannelIndex, numpyKernel, dictBooleanArguments, dictNumericArguments,
                             dictStringArguments):
        """
        Method to ask the filter for the extent required (in X) to perform computations on a subset of the dataset
        :param inputChannelIndex: int
            Channel index for which the size dependence is requested
        :return: int
            Largest number of pixels required in X from the pixel of computation (either side).
            If all the pixels are required, return -1.
        """
        return -1  # The filter requires all the pixels on the slice to be computed

    @classmethod
    def getLengthDependenceY(cls, inputChannelIndex, numpyKernel, dictBooleanArguments, dictNumericArguments,
                             dictStringArguments):
        """
        Method to ask the filter for the extent required (in Y) to perform computations on a subset of the dataset
        :param inputChannelIndex: int
            Channel index for which the size dependence is requested
        :return: int
            Largest number of pixels required in Y from the pixel of computation (either side).
            If all the pixels are required, return -1.
        """
        return -1  # The filter requires all the pixels on the slice to be computed

    @classmethod
    def getLengthDependenceZ(cls, inputChannelIndex, numpyKernel, dictBooleanArguments, dictNumericArguments,
                             dictStringArguments):
        """
        Method to ask the filter for the extent required (in Z) to perform computations on a subset of the dataset
        :return: int
            0, since each slice is classified independently.
        """
        return 0  # 2D filter. Adjacent slices are not required to compute the output of the filter for a given slice.

    @classmethod
    def getShowKernelShape(cls):
        """
        :return: bool
            Return True if the kernel shape (square, circle, ...) should be visible in the UI, False otherwise.
        """

        return False

    @classmethod
    def getShowKernelDim(cls):
        """
        :return: bool
            Return True if the kernel dimensionality (2D, 3D) should be visible in the UI, False otherwise.
        """

        return False

    @classmethod
    def getShowKernelSize(cls):
        """
        :return: bool
            Return True if the kernel size (3, 5, 7, ...) should be visible in the UI, False otherwise.
        """

        return False

    @classmethod
    def getNumericArguments(cls):
        """
        :return: dict
            Accepted arguments of numerical type, with their default values
        """

        return {'classCount': 2}

    @classmethod
    def getIsNumericArgumentValid(cls, argumentName, argumentValue):
        """
        :param argumentName: str
            Argument name
        :param argumentValue: numeric
            Argument value
        :return: bool
            Returns True if the argument value (for the argument name) is valid, False otherwise.
        """

        if not FilterUtil.getIsArgumentValidAsNumeric(argumentValue):
            return False

        if argumentName == 'classCount':
            # classCount should be an integer between 2 and 255 inclusive.
            # int() raises on NaN/inf (or non-numeric types); such values are simply invalid.
            try:
                integerValue = int(argumentValue)
            except (ValueError, OverflowError, TypeError):
                return False
            if integerValue != argumentValue:
                return False
            return 2 <= integerValue <= 255

        return False

    @classmethod
    def getArgumentDescriptors(cls):
        """
        :return: list
            List of argument descriptors (see class "ArgumentDescriptor"),
            in the order wanted for the display in the UI.
        """

        descriptorclassCount = FilterAbstract.ArgumentDescriptorNumeric()  # Instantiation
        descriptorclassCount.setArgumentName('classCount')
        descriptorclassCount.setLabel('Class count:')
        descriptorclassCount.setDecimalCount(0)
        descriptorclassCount.setMinimumValue(2)
        descriptorclassCount.setMaximumValue(255)

        return [descriptorclassCount]
