#!/usr/bin/env python
# (C) 2022 Cadence Design Systems, Inc. (Cadence) 
# All rights reserved.
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of Cadence products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. Cadence claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable Cadence offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
# liable for any damages or liability in connection with the Sample Code
# or its use.



'''
This script builds a Gaussian grid from the first molecule of an input file
and fits a set of shape Gaussians to that grid.

Usage:
python grid_to_gaussians_fit.py -in input_mol.oeb [-gridresolution 0.5]

'''

import os

from openeye import oechem, oeshape, oegrid


class GridToGaussiansOptions(oechem.OEOptions):
    """Command-line options for the grid-to-Gaussians fit.

    Registers a single float parameter, ``-gridresolution`` (default 0.5),
    and embeds the toolkit's ``OEShapeQueryOptions`` option set.
    """

    def __init__(self):
        oechem.OEOptions.__init__(self, "GridToGaussiansOptions")

        # Spacing used when rasterizing the input molecule into a grid.
        resParam = oechem.OEFloatParameter("-gridresolution", 0.5)
        resParam.SetIsList(False)
        resParam.SetRequired(True)
        resParam.SetVisibility(oechem.OEParamVisibility_Simple)
        resParam.SetBrief("Grid resolution for converting input molecule into a grid query")
        self._gridResolution = self.AddParameter(resParam)

        # AddOption hands back a generic option handle; keep the typed
        # shape-query view of it for later retrieval.
        addedOption = self.AddOption(oeshape.OEShapeQueryOptions())
        self._shapeQueryOpts = oeshape.ToShapeQueryOptions(addedOption)

    def CreateCopy(self):
        # Hands back this same instance rather than constructing a new copy.
        return self

    def GetGridResolution(self):
        """Return the grid resolution: the user-supplied value if present, else the default."""
        param = self._gridResolution
        text = param.GetStringValue() if param.GetHasValue() else param.GetStringDefault()
        return float(text)

    def GetShapeQueryOptions(self):
        """Return the embedded shape-query options."""
        return self._shapeQueryOpts


def _output_prefix(filename):
    """Return *filename* with everything from the first dot of its basename removed.

    Only the basename is split, so dots in directory components
    (``./mol.oeb``, ``data.v2/mol.oeb``) do not truncate the path.
    Multi-part extensions are still stripped entirely (``mol.oeb.gz`` -> ``mol``).
    A basename with an empty stem (e.g. ``.oeb``) is kept unchanged.
    """
    head, tail = os.path.split(filename)
    stem = tail.split(".")[0] or tail
    return os.path.join(head, stem)


def main(argv=None):
    """Fit shape Gaussians to a Gaussian grid built from the input molecule.

    Reads the first molecule from ``-in`` and suppresses its hydrogens. It then
    builds a molecular Gaussian grid at ``-gridresolution`` and converts that
    grid into shape Gaussians with ``OEShapeQuery.AddShapeGaussians``. Three
    files are written next to the input, sharing its name stem:

    * ``<stem>.grd``: the grid;
    * ``<stem>-grid-to-gaussians-query.sq``: the shape query;
    * ``<stem>-grid-to-gaussians-mol.oeb``: a molecule with one carbon atom
      at each Gaussian's center.

    :param argv: command-line arguments, including the program name.
        Defaults to ``[__name__]``.
    :returns: 0 (both when only help was printed and on success).
    """
    # Avoid a shared mutable default; keep the original default contents.
    if argv is None:
        argv = [__name__]

    genOpts = GridToGaussiansOptions()

    opts = oechem.OESimpleAppOptions(
        genOpts, "GridToGaussiansOptions",
        oechem.OEFileStringType_Mol3D)

    if oechem.OEConfigureOpts(opts, argv, False) == oechem.OEOptsConfigureStatus_Help:
        return 0

    genOpts.UpdateValues(opts)

    inputFilename = opts.GetInFile()
    resolution = genOpts.GetGridResolution()
    prefix = _output_prefix(inputFilename)
    outputGridFileName = "{}.grd".format(prefix)
    outputFilename = "{}-grid-to-gaussians-query.sq".format(prefix)
    outputOEBFilename = "{}-grid-to-gaussians-mol.oeb".format(prefix)

    ifs = oechem.oemolistream()
    if not ifs.open(inputFilename):
        oechem.OEThrow.Fatal(f'Unable to open {inputFilename} for reading')

    # Only the first molecule is used, so release the input stream right away.
    mol = oechem.OEMol()
    readOK = oechem.OEReadMolecule(ifs, mol)
    ifs.close()
    if not readOK:
        oechem.OEThrow.Fatal(f'Unable to read a molecule from {inputFilename}')

    oechem.OESuppressHydrogens(mol)
    mol.Sweep()

    grid = oegrid.OEScalarGrid()
    oegrid.OEMakeMolecularGaussianGrid(grid, mol, resolution)
    if not oegrid.OEWriteGrid(outputGridFileName, grid):
        oechem.OEThrow.Fatal(f'Unable to write grid to {outputGridFileName}')

    sq = oeshape.OEShapeQuery(genOpts.GetShapeQueryOptions())
    sq.AddShapeGaussians(grid)

    # Build a carbon-only molecule whose atoms sit at the Gaussian centers;
    # coordinates are collected as a flat x, y, z list for SetCoords.
    coords = []
    outMol = oechem.OEMol()
    for gauss in sq.GetShapeGaussians():
        coords.extend((gauss.GetX(), gauss.GetY(), gauss.GetZ()))
        outMol.NewAtom(oechem.OEElemNo_C)

    outMol.SetCoords(oechem.OEFloatArray(coords))

    oeshape.OEWriteShapeQuery(outputFilename, sq)

    with oechem.oemolostream(outputOEBFilename) as ofs:
        oechem.OEWriteMolecule(ofs, outMol)

    return 0

if __name__ == '__main__':
    # Run as a script: pass the process arguments through and use main()'s
    # return value as the exit status.
    from sys import argv as _argv, exit as _exit
    _exit(main(_argv))