#! /usr/bin/env python
# S.Rodney
# 2010.06.12
# v2 2010.09.16 : bug fixes. Thanks to Hai Fu and Hanindyo Kuncarayakti
"""
Convert an IFU FITS file from Euro3d format (FITS binary tables)
to a traditional IFU FITS data cube (3-d data array) and a
white-light image (2-d image array).

Syntax:  e3d2cube.py [--clobber] infile outroot
  ==> outroot_cube.fits, outroot_white.fits

Options:
  -h, --help : print this message and exit
  --clobber  : overwrite existing output files

"""


import sys
import pyfits

def main():
    """
    Command-line entry point.

    Parse sys.argv, convert the Euro3d file named by the first positional
    argument, and write <outroot>_cube.fits and <outroot>_white.fits, where
    outroot is the second positional argument.

    Returns 0 on success or when help is requested, and -1 when the
    command line is invalid (unknown option or missing arguments).
    """
    import getopt
    clobber = False

    # read in arguments and options
    try:
        opt, arg = getopt.getopt(
            sys.argv[1:], "h",
            longopts=["help", "clobber"] )
    except getopt.GetoptError:
        print(__doc__)
        return(-1)

    for o, _ in opt:
        if o in ["-h", "--help"]:
            print(__doc__)
            return(0)
        elif o == '--clobber':
            clobber = True

    # both infile and outroot are required
    if len(arg) < 2:
        print(__doc__)
        return(-1)
    infile = arg[0]
    outroot = arg[1]

    # convert from e3d to a data cube and a white-light image
    cube, white = convert( infile )
    cube.writeto( outroot+"_cube.fits", clobber=clobber )
    white.writeto( outroot+"_white.fits", clobber=clobber )
    return(0)


def convert( infile ) :
    """
    Build a data cube and a white-light image from a Euro3d file.

    infile : name of a Euro3d FITS file containing an 'E3D_DATA' binary
             table extension (columns SPEC_LEN, SPEC_STA, XPOS, YPOS,
             DATA_SPE are read).

    Returns a tuple (cube, white) of pyfits PrimaryHDU objects:
      cube  : 3-d image with numpy shape (Nw, Ny, Nx), i.e. FITS axis 1
              is x, axis 2 is y and axis 3 is wavelength.  Spaxels with
              no spectrum in the input stay at zero.
      white : 2-d image with numpy shape (Ny, Nx): the cube summed over
              the wavelength axis.
    """
    from numpy import zeros

    e3d = pyfits.open( infile )
    try :
        dhdr = e3d['E3D_DATA'].header
        data = e3d['E3D_DATA'].data

        w0 = dhdr['CRVALS']  # reference wavelength
        try :
            dw = dhdr['CRDELTS']  # wavelength step
        except KeyError :
            dw = dhdr['CDELTS']  # wavelength step, alternate keyword

        # define the dimensions of the spaxel array
        Nx, Ny, x0, y0, dx, dy = getspaxdim( data )
        Nw = int( data[0]['SPEC_LEN'] )  # number of wave. steps

        # array indices of the spaxel that each spectrum belongs to
        ixs = [ _spaxel_index( row['XPOS'], x0, dx ) for row in data ]
        iys = [ _spaxel_index( row['YPOS'], y0, dy ) for row in data ]
        # never let a spaxel fall off the edge of the array, even if the
        # size estimate from getspaxdim came out short
        Nx = max( [Nx] + [ ix+1 for ix in ixs ] )
        Ny = max( [Ny] + [ iy+1 for iy in iys ] )

        # numpy axis order is the reverse of the FITS order:
        # (NAXIS3, NAXIS2, NAXIS1) = (wavelength, y, x)
        cubedata = zeros( (Nw, Ny, Nx) )

        # extract each spectrum and place it into the 3-d cube
        for row, ix, iy in zip( data, ixs, iys ) :
            Nwoff = int( row['SPEC_STA'] )  # starting offset, in wave. steps
            spec = row['DATA_SPE']
            # spec[k] belongs at cube wavelength index k+Nwoff.  Copy only
            # samples with 0 <= k < len(spec) and 0 <= k+Nwoff < Nw, so a
            # positive offset never reads spec with a negative index
            kmin = max( 0, -Nwoff )
            kmax = min( int( row['SPEC_LEN'] ), len( spec ), Nw - Nwoff )
            if kmax > kmin :
                cubedata[ kmin+Nwoff : kmax+Nwoff, iy, ix ] = spec[ kmin:kmax ]
    finally :
        # everything needed below has been copied out of the file
        e3d.close()

    # the 3-d cube
    cube = pyfits.PrimaryHDU( data=cubedata )
    cube.header.update('NAXIS',3)
    cube.header.update('NAXIS1',Nx,after='NAXIS')
    cube.header.update('NAXIS2',Ny,after='NAXIS1')
    cube.header.update('NAXIS3',Nw,after='NAXIS2')
    _update_spatial_wcs( cube.header, x0, y0, dx, dy )
    # wavelength axis: array index 0 (FITS pixel 1) is at w0
    cube.header.update('CRPIX3',1)
    cube.header.update('CRVAL3',w0)
    cube.header.update('CD3_3',dw)
    cube.header.update('CTYPE3','ANGSTROM')
    for key in ('CD1_3','CD2_3','CD3_1','CD3_2') :
        cube.header.update(key,0)

    # a 2-d white-light image: the cube summed over wavelength
    white = pyfits.PrimaryHDU( data=cubedata.sum( axis=0 ) )
    white.header.update('NAXIS1',Nx,after='NAXIS')
    white.header.update('NAXIS2',Ny,after='NAXIS1')
    _update_spatial_wcs( white.header, x0, y0, dx, dy )

    return( cube, white )


def _spaxel_index( pos, origin, step ) :
    """ Map a spaxel position (arcsec) to a 0-based array index. """
    # round the offset to 0.01" (the precision getspaxdim works in), then
    # round -- not truncate -- the ratio: 0.3/0.1 is 2.9999999999999996
    return int( round( round( pos - origin, 2 ) / step ) )


def _update_spatial_wcs( hdr, x0, y0, dx, dy ) :
    """ Write the linear x/y (FITS axes 1 and 2) WCS keywords into hdr. """
    # FITS pixel coordinates are 1-based: array index 0 is pixel 1 and
    # holds the spaxel at (x0, y0)
    hdr.update('CRPIX1',1)
    hdr.update('CRPIX2',1)
    hdr.update('CRVAL1',x0)
    hdr.update('CRVAL2',y0)
    hdr.update('CD1_1',dx)
    hdr.update('CD2_2',dy)
    hdr.update('CD1_2',0)
    hdr.update('CD2_1',0)
    hdr.update('CTYPE1','ARCSEC')
    hdr.update('CTYPE2','ARCSEC')



def getspaxdim( data ):
    """
    Define the dimensions of the spaxel array.  Return values are
      Nx : number of spaxels in the x direction
      Ny : number of spaxels in the y direction
      x0 : smallest x position, in arcseconds (rounded to 0.01)
      y0 : smallest y position, in arcseconds (rounded to 0.01)
      dx : arcseconds between adjacent spaxels in the x direction
      dy : arcseconds between adjacent spaxels in the y direction

    data : sequence of table rows, each indexable by the column names
           'XPOS' and 'YPOS' (spaxel positions in arcseconds)

    The E3D format doesn't have any explicit parameters defining the
    spaxel array size, and it is possible that some spaxels are not
    included in the data arrays.  So we read in all the x and y
    positions, find the physical separation between spaxel centers,
    and use that to define the array.

    Raises ValueError if fewer than two distinct positions are present
    (no spacing can be measured).
    """
    # unique physical spaxel positions, rounded to the nearest 0.01"
    xpos = sorted( set( round( row['XPOS'], 2 ) for row in data ) )
    ypos = sorted( set( round( row['YPOS'], 2 ) for row in data ) )

    # unique separations between neighbouring positions,
    # also rounded to the nearest 0.01"
    xsteps = set( round( b - a, 2 ) for a, b in zip( xpos, xpos[1:] ) )
    ysteps = set( round( b - a, 2 ) for a, b in zip( ypos, ypos[1:] ) )

    if not xsteps and not ysteps :
        raise ValueError( "cannot determine the spaxel spacing: "
                          "fewer than two distinct spaxel positions" )

    # We take the minimum separation as the spacing between adjacent
    # spaxels (gaps from missing spaxels only produce larger steps).
    # With a single row or column there is nothing to measure along one
    # axis, so assume square spaxels and borrow the other axis' spacing.
    dxspax = min( xsteps ) if xsteps else min( ysteps )
    dyspax = min( ysteps ) if ysteps else min( xsteps )

    # The spaxel array spans the total x and y extent with equally spaced
    # spaxels.  Round, don't truncate: 0.3/0.1 is 2.9999999999999996.
    NXSPAX = int( round( (xpos[-1] - xpos[0]) / dxspax ) ) + 1
    NYSPAX = int( round( (ypos[-1] - ypos[0]) / dyspax ) ) + 1

    return( NXSPAX, NYSPAX, xpos[0], ypos[0], dxspax, dyspax )


if __name__ == "__main__":
    # propagate main()'s status (0 on success, -1 on a bad command line)
    # as the process exit code
    sys.exit( main() )
