/****************************************************************************
 * Copyright (c) 2007 Einir Valdimarsson and Chrysanthe Preza
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 ****************************************************************************/

#ifndef _ESTIMATE_CG_ML_POISSON_H
#define _ESTIMATE_CG_ML_POISSON_H

#include "blitz/fftwInterface.h"
#include <blitz/array.h>

#include <string>

namespace cosm {

// Raised-cosine ramp sampler.
// Returns 'size' samples of 0.5 + 0.5*cos(theta), with theta running in equal steps
// from pi (first sample, value 0) down to 0 (last sample, value 1).
// size    : number of samples
// returns : 1-D array holding the ramp
// A ramp of a single sample has no step between samples: the generic formula would
// divide by zero and store NaN. That sample is set to 0.5, the midpoint of the ramp.
template<typename T>
Array<T, 1> csampler(
    int size
){
    Array<T, 1> cs(size);
    cs = 0;

    if (size == 1)
    {
        // M_PI/(size - 1) is a division by zero here; avoid cos(0*inf) == NaN.
        cs(0) = T(0.5);
        return cs;
    }

    // Angular step between neighbouring samples.
    const T fac = M_PI/(size - 1);

    for (int i = 0; i < size; i++)
    {
        cs(i) = 0.5 + 0.5*cos((size - 1 - i)*fac);
    }

    return cs;
}

//Cosine apodizer.
// Copies 'img' into the centre of a volume enlarged by a margin on both sides of its first
// three dimensions, then fills the margins with a csampler() ramp that blends between the
// mean of 'img' (towards the outer face) and the value found at the adjacent edge.
// img     : input volume (accessed with three indices, so effectively N == 3)
// Bz      : margin added before and after dimension 0
// Bx      : margin added before and after dimension 1
// By      : margin added before and after dimension 2
// returns : array of extent (Z+2*Bz, X+2*Bx, Y+2*By)
// The three passes must run in this order: each later pass reads margin values written by
// the earlier ones, which is how the edges and corners of the enlarged volume get filled.
// All loop counters are zero-based array indices (C convention), not one-based positions.
template<typename T, int N>
Array<T, N> apodize(
    Array<T, N>& img,
    int Bz,
    int Bx,
    int By
){
    const int nz = img.length(0);
    const int nx = img.length(1);
    const int ny = img.length(2);

    Array<T, N> padded(nz + 2*Bz, nx + 2*Bx, ny + 2*By);
    padded = 0;

    // Place the untouched image in the middle of the enlarged volume.
    padded(Range(Bz, Bz+nz-1), Range(Bx, Bx+nx-1), Range(By, By+ny-1)) = img;

    const int paddedX = padded.length(1);
    const int paddedY = padded.length(2);

    const T avg = mean(img);

    // Pass 1: margins along dimension 1, for the z/y positions covered by the image.
    Array<T, 1> rampX = csampler<T>(Bx);

    for (int iz = 0; iz < nz; ++iz)
    {
        for (int iy = 0; iy < ny; ++iy)
        {
            const T lowEdge = img(iz, 0, iy) - avg;
            const T highEdge = img(iz, nx-1, iy) - avg;

            for (int k = 0; k < Bx; ++k)
            {
                padded(Bz+iz, k, By+iy) = lowEdge*rampX(k) + avg;
                // The far-side margin uses the ramp mirrored.
                padded(Bz+iz, nx+Bx+k, By+iy) = highEdge*rampX(Bx-k-1) + avg;
            }
        }
    }

    // Pass 2: margins along dimension 2, across the full enlarged width of dimension 1.
    Array<T, 1> rampY = csampler<T>(By);

    for (int iz = 0; iz < nz; ++iz)
    {
        for (int ix = 0; ix < paddedX; ++ix)
        {
            const T lowEdge = padded(Bz+iz, ix, By) - avg;
            const T highEdge = padded(Bz+iz, ix, ny+By-1) - avg;

            for (int k = 0; k < By; ++k)
            {
                padded(Bz+iz, ix, k) = lowEdge*rampY(k) + avg;
                padded(Bz+iz, ix, By+ny+k) = highEdge*rampY(By-k-1) + avg;
            }
        }
    }

    // Pass 3: margins along dimension 0, across the full enlarged dimension-1/2 plane.
    Array<T, 1> rampZ = csampler<T>(Bz);

    for (int iy = 0; iy < paddedY; ++iy)
    {
        for (int ix = 0; ix < paddedX; ++ix)
        {
            const T lowEdge = padded(Bz, ix, iy) - avg;
            const T highEdge = padded(nz+Bz-1, ix, iy) - avg;

            for (int k = 0; k < Bz; ++k)
            {
                padded(k, ix, iy) = lowEdge*rampZ(k) + avg;
                padded(nz+Bz+k, ix, iy) = highEdge*rampZ(Bz-k-1) + avg;
            }
        }
    }

    return padded;
}

// Same to 'ftconv.m'
// x     : input matrix
// h     : kernel matrix
// y     : result matrix with size(x) equal to size(y)
template<typename T, int N>
void ftconv(
    const Array<T, N>& x,
    const Array<T, N>& h,
    Array <T, N>& y,
    bool adjoint = false
){
    Array<T, N> img = x;
    Array<T, N> psf = h;

    TinyVector<int,N> sizeImg = img.length();
    TinyVector<int,N> sizePsf = psf.length();

    // Resize (enlarge) image and Psf
    //for img, padding by csampler algo
    TinyVector<int,N> halfSize = ceil((sizePsf - sizeImg)/2);
    Array<T,N> padImg = apodize<T, N>(img, halfSize(0), halfSize(1), halfSize(2));

    //Conv operation
    Array<std::complex<T>,N> Afft = cosm::forwardFFT(padImg);
    Array<std::complex<T>,N> Bfft = cosm::forwardFFT(psf);

    if (adjoint){
        Bfft = conj(Bfft);
    }

    Bfft /= Bfft(0, 0, 0);

    Array<std::complex<T>,N> fftCov(Afft.extent());
    fftCov = Afft * Bfft;

    Array<T,N> newImg = cosm::inverseFFT(fftCov);

    //Shrink
    Array<T,N> shrinkedImg (sizeImg);

    TinyVector<int,N> newSizeImg = newImg.length();

    TinyVector<int,N> bc = ceil((newSizeImg - sizeImg)/2);

    RectDomain<N> imgScope(bc, bc+sizeImg-1);

    shrinkedImg = newImg(imgScope);

    y = shrinkedImg;

};

// Same to 'dvftconv.m'
// x     : input matrix
// h     : stratum kernel's h(y,x,z,stratum)
// y     : result matrix with size(x) equal to size(y)
template<typename T, int N>
void dvftconv(
    const Array<T, N>& x,
    const Array<Array<T,N>, 1>& h,
    Array <T, N>& y,
    bool adjoint = false
){

    //cout << "--Start dvftconv... "<<endl;

    Array<T, N> img = x;
    Array<Array<T,N>, 1> psfs = h;
    int numberOfStrata = psfs.extent(0) - 1;

    Array<T,N> y1(img.extent());
    Array<T,N> y2(img.extent());
    y1 = 0;
    y2 = 0;

    int startSlide = 1;    //Of each stratum in Matlab stype
    int endSlide;    //Of each stratum in Matlab stype

    int numberOfSlides = img.extent(0); //total number of z-slices

    if (adjoint){// the overlap-save method
        //cout << "--The overlap-save method "<<endl;

        ftconv<T, N>(img, psfs(0), y1, adjoint);    //conv. with PSF 1


        for (int i = 2; i <=numberOfStrata+1; i++ ){//Through each stratum

            ftconv<T, N>(img, psfs(i-1), y2, adjoint);    //conv. with PSF 2 ff.

            endSlide = floor(double((i-1)*numberOfSlides/numberOfStrata));

            if (endSlide > numberOfSlides){
                endSlide = numberOfSlides;
            }

            int size = endSlide - startSlide + 1;

            Array<T,1> w (size);

            w = 0;

            for (int j = 0; j < size; j++){
                w(j) = T(j)/(T(size) - T(1));
            }

            //copy stratum (weighted PSF 1+2) in result
            for (int j = startSlide - 1; j <= endSlide - 1; j++){
                int i = j - startSlide + 1;
                y(j, Range::all(), Range::all()) = y2(j, Range::all(), Range::all()) * w(i) + 
                    y1(j, Range::all(), Range::all()) * (1 - w(i));
            }

            startSlide = endSlide + 1;

            y1 = y2;
        }
    } else{ // the overlap-add method
        //cout << "--The overlap-add method "<<endl;

        Array<T,N> x1(img.extent());
        Array<T,N> x2(img.extent());
        x1 = 0;
        x2 = 0;

        Array<T,1> w1 (numberOfSlides);
        Array<T,1> w2 (numberOfSlides);

        for (int i = 1; i <=numberOfStrata; i++ ){//Through each stratum

            endSlide = floor(double(i*numberOfSlides/numberOfStrata));

            if (endSlide > numberOfSlides){
                endSlide = numberOfSlides;
            }

            w1 = 0;
            w2 = 0;

            int w_lower = startSlide - 1;
            int size = endSlide - startSlide + 1;

            for (int j = 0; j < size; j++){
                w1(w_lower + j) = T(1) - T(j)/(T(size) - T(1));
                w2(w_lower + j) = T(j)/(T(size) - T(1));
            }

            //Only supports 3D now
            for (int j = 0; j < numberOfSlides; j++){
                x1(j, Range::all(),Range::all()) = img(j, Range::all(),Range::all()) * w1(j);
                x2(j, Range::all(),Range::all()) = img(j, Range::all(),Range::all()) * w2(j);
            }
                
            startSlide = endSlide + 1;

            y1 = 0;
            y2 = 0;

            ftconv<T, N>(x1, psfs(i-1), y1, adjoint);
            ftconv<T, N>(x2, psfs(i), y2, adjoint);

            y = y + y1 + y2;

        }
    }

    //cout << "--End dvftconv. " <<endl;
};

//Internal function 'PDiv'
// Guarded element-wise division.
// x  : numerator
// y  : denominator
// dv : receives 1 wherever x or y is zero, and x/y everywhere else
template<typename T, int N>
void PDiv(
    const Array<T, N>& x,
    const Array<T, N>& y,
    Array <T, N>& dv
){
    // Note that a zero numerator also yields 1, not 0.
    dv = where(((x == 0) || (y == 0)), 1, x/y);
};

//Internal function 'GradL'
template<typename T, int N>
void GradL(
    const Array<T, N>& Kss,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<Array<T,N>, 1>& h,
    Array <T, N>& gradL1
){
    Array<T, N> y(Kss.extent()), img(Kss.extent()), est(Kss.extent());
    img = 0;
    est = 0;

    y = Kss+b;
    PDiv<T, N>(g, y, img);
    img = 1- img;

    dvftconv<T, N>(img, h, est, true); // true == with 'adjoint'

    gradL1 = -2 * s * est;
};

//Internal function 'GradR'
template<typename T, int N>
void GradR(
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    Array <T, N>& gradR
){
    Array<T, N> tmp(s.extent());
    tmp = 0;
    
    tmp = pow2(s) - g + b;
    gradR = -4 * s * tmp;

};

//Internal function 'Grad'
template<typename T, int N>
void Grad(
    const Array<T, N>& Kss,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const T beta,
    const Array<Array<T,N>, 1>& h,
    Array <T, N>& grad1
){
    Array<T, N> gradL(Kss.extent()), gradR(Kss.extent());
    gradL = 0;
    gradR = 0;

    GradL<T, N>(Kss, s, g, b, h, gradL);

    GradR<T, N>(s, g, b, gradR);
    
    grad1 = gradL + beta * gradR;

};

//Internal function 'DHessL'
template<typename T, int N>
T DHessL(
    const Array<T, N>& Kss,
    const Array<T, N>& Ksdk,
    const Array<T, N>& Kdd,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const T lambda
){
    T dHessL = 0;

    Array<T, N> a(Kss.extent()), bb(Kss.extent()), c(Kss.extent()), tmp(Kss.extent()), x(Kss.extent());
    a = 0;
    bb = 0;
    c = 0;
    tmp = 0;
    x = 0;

    a = Kss + lambda * (2 * Ksdk + lambda * Kdd) + b;

    PDiv<T, N>(g, a, tmp);

    bb = 2 * Kdd * (1 - tmp);

    x = Ksdk + lambda * Kdd;

    tmp = 0;

    PDiv<T, N>(x, a, tmp);

    c = 4 * g * pow2(tmp);

    dHessL = sum(bb) + sum(c);

    return dHessL;
};

//Internal function 'DHessR'
template<typename T, int N>
T DHessR(
    const Array<T, N>& s,
    const Array<T, N>& dk,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const T lambda
){
    T dHessR = 0;

    Array<T, N> sl(dk.extent()), hs(dk.extent());
    sl = 0;
    hs = 0;

    sl = s + lambda * dk;

    hs = 4* dk * dk * (3 * pow2(sl) - g + b);

    dHessR = sum (hs);

    return dHessR;
};

//Internal function 'DHess'
// Returns DHessL(...) + beta * DHessR(...), both evaluated at the same lambda.
// Kss, Ksdk, Kdd, s, g, b, dk : input arrays, forwarded to the two helpers
// beta                        : scalar weight applied to the DHessR part
// lambda                      : scalar passed to both helpers
template<typename T, int N>
T DHess(
    const Array<T, N>& Kss,
    const Array<T, N>& Ksdk,
    const Array<T, N>& Kdd,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& dk,
    const T beta,
    const T lambda
){
    const T partL = DHessL<T, N>(Kss, Ksdk, Kdd, g, b, lambda);
    const T partR = DHessR<T, N>(s, dk, g, b, lambda);

    return partL + beta * partR;
}

//Internal function 'DGradL'
template<typename T, int N>
T DGradL(
    const Array<T, N>& Kss,
    const Array<T, N>& Ksdk,
    const Array<T, N>& Kdd,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const T lambda
){
    T dGradL = 0;

    Array<T, N> tmp(Kss.extent()), y(Kss.extent()), hs(Kss.extent());
    tmp = 0;
    y = 0;
    hs = 0;

    y = (Kss + lambda * (2 * Ksdk + lambda * Kdd)) + b;

    PDiv<T, N>(g, y, tmp);

    hs = 2 * (Ksdk + lambda * Kdd) * (1 - tmp);

    dGradL = sum(hs);

    return dGradL;
};

//Internal function 'DGradR'
template<typename T, int N>
T DGradR(
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& dk,
    const T lambda
){
    T dGradR = 0;

    Array<T, N> a(s.extent()), gr(s.extent());
    a = 0;
    gr = 0;

    a = pow2((s + lambda * dk));

    gr = 4 * dk * s * (a - g + b);

    dGradR = sum(gr);

    return dGradR;
};

//Internal function 'DGrad'
// Returns DGradL(...) + beta * DGradR(...), both evaluated at the same scalar X.
// Kss, Ksdk, Kdd, s, g, b, dk : input arrays, forwarded to the two helpers
// beta                        : scalar weight applied to the DGradR part
// X                           : scalar passed to both helpers as their 'lambda'
template<typename T, int N>
T DGrad(
    const Array<T, N>& Kss,
    const Array<T, N>& Ksdk,
    const Array<T, N>& Kdd,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& dk,
    const T beta,
    const T X
){
    const T partL = DGradL<T, N>(Kss, Ksdk, Kdd, g, b, X);
    const T partR = DGradR<T, N>(s, g, b, dk, X);

    return partL + beta * partR;
}

//Internal function 'newkss'
// Writes y = oldkss + 2*lambda*ksdk + lambda^2*kdd.
// oldkss, ksdk, kdd : input arrays of equal extent
// lambda            : scalar
// y                 : output array
template<typename T, int N>
void newkss(
    const Array<T, N>& oldkss,
    const Array<T, N>& ksdk,
    const Array<T, N>& kdd,
    const T lambda,
    Array <T, N>& y
){
    // Scalar coefficients of the two array terms.
    const T linearFactor = 2 * lambda;
    const T squareFactor = lambda * lambda;

    y = oldkss + linearFactor * ksdk + squareFactor * kdd;
}

//Internal function 'NewtonRaphson'
// Runs up to three Newton-Raphson updates lambda <- lambda - DGrad/DHess, starting from
// lambda = 0, then writes kss = Kss + 2*lambda*Ksdk + lambda^2*Kdd through newkss().
// Kss, g, b, s, dk : input arrays of equal extent
// beta             : scalar weight forwarded to DGrad()/DHess()
// h                : stratum kernels, forwarded to dvftconv()
// kss              : output array
// adjoint          : forwarded to dvftconv()
// returns          : the final lambda
// If DHess() returns exactly zero (for example when dk is zero everywhere) the iteration
// stops and the current lambda is kept, instead of dividing by zero and spreading
// NaN/inf into lambda and kss.
template<typename T, int N>
T NewtonRaphson(
    const Array<T, N>& Kss,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& s,
    const Array<T, N>& dk,
    const T beta,
    const Array<Array<T,N>, 1>& h,
    Array <T, N>& kss,
    bool adjoint
){
    Array<T, N> Ksdk(s.extent()), Kdd(s.extent()), tmp(s.extent());
    // Both convolution targets start at zero: dvftconv() adds into its output when
    // 'adjoint' is false.
    Ksdk = 0;
    Kdd = 0;

    // Ksdk = dvftconv(s * dk)
    tmp = s * dk;
    dvftconv<T, N>(tmp, h, Ksdk, adjoint);

    // Kdd = dvftconv(dk^2)
    tmp = pow2(dk);
    dvftconv<T, N>(tmp, h, Kdd, adjoint);

    const int maxIterations = 3;
    T lambda = 0;

    for (int n = 1; n <= maxIterations; n++){
        const T dGrad = DGrad<T, N>(Kss, Ksdk, Kdd, s, g, b, dk, beta, lambda);
        const T dHess = DHess<T, N>(Kss, Ksdk, Kdd, s, g, b, dk, beta, lambda);

        if (dHess == 0){
            // No usable curvature: a Newton step is undefined here.
            break;
        }

        lambda = lambda - dGrad / dHess;
    }

    newkss<T, N>(Kss, Ksdk, Kdd, lambda,kss);

    return lambda;
}

//Internal function 'NewtonRaphson_SI'
// Single-kernel variant of the line search: runs up to three Newton-Raphson updates
// lambda <- lambda - DGrad/DHess, starting from lambda = 0, then writes
// kss = Kss + 2*lambda*Ksdk + lambda^2*Kdd through newkss().
// Kss, g, b, s, dk : input arrays of equal extent
// beta             : scalar weight forwarded to DGrad()/DHess()
// h                : kernel, forwarded to ftconv() (without 'adjoint')
// kss              : output array
// returns          : the final lambda
// If DHess() returns exactly zero (for example when dk is zero everywhere) the iteration
// stops and the current lambda is kept, instead of dividing by zero and spreading
// NaN/inf into lambda and kss.
template<typename T, int N>
T NewtonRaphson_SI(
    const Array<T, N>& Kss,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& s,
    const Array<T, N>& dk,
    const T beta,
    const Array<T, N>& h,
    Array <T, N>& kss
){
    Array<T, N> Ksdk(s.extent()), Kdd(s.extent()), tmp(s.extent());
    Ksdk = 0;
    Kdd = 0;

    // Ksdk = ftconv(s * dk)
    tmp = s * dk;
    ftconv<T, N>(tmp, h, Ksdk);

    // Kdd = ftconv(dk^2)
    tmp = pow2(dk);
    ftconv<T, N>(tmp, h, Kdd);

    const int maxIterations = 3;
    T lambda = 0;

    for (int n = 1; n <= maxIterations; n++){
        const T dGrad = DGrad<T, N>(Kss, Ksdk, Kdd, s, g, b, dk, beta, lambda);
        const T dHess = DHess<T, N>(Kss, Ksdk, Kdd, s, g, b, dk, beta, lambda);

        if (dHess == 0){
            // No usable curvature: a Newton step is undefined here.
            break;
        }

        lambda = lambda - dGrad / dHess;
    }

    newkss<T, N>(Kss, Ksdk, Kdd, lambda,kss);

    return lambda;
}

//Internal function 'GradL_SI'
template<typename T, int N>
void GradL_SI(
    const Array<T, N>& Kss,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const Array<T, N>& h,
    Array <T, N>& gradL1
){
    Array<T, N> y(Kss.extent()), img(Kss.extent()), est(Kss.extent());
    img = 0;
    est = 0;

    y = Kss+b;
    PDiv<T, N>(g, y, img);
    img = 1- img;

    ftconv<T, N>(img, h, est, true); // true == with 'adjoint'

    gradL1 = -2 * s * est;
};

//Internal function 'Grad_SI'
template<typename T, int N>
void Grad_SI(
    const Array<T, N>& Kss,
    const Array<T, N>& s,
    const Array<T, N>& g,
    const Array<T, N>& b,
    const T beta,
    const Array<T, N>& h,
    Array <T, N>& grad1
){
    Array<T, N> gradL(Kss.extent()), gradR(Kss.extent());
    gradL = 0;
    gradR = 0;

    GradL_SI<T, N>(Kss, s, g, b, h, gradL);

    GradR<T, N>(s, g, b, gradR);
    
    grad1 = gradL + beta * gradR;

};

};
#endif // _ESTIMATE_CG_ML_POISSON_H
