#!/usr/bin/env python
# (C) 2024 OpenEye Scientific Software Inc. All rights reserved.
#
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of OpenEye products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. OpenEye claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable OpenEye offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall OpenEye be
# liable for any damages or liability in connection with the Sample Code
# or its use.


'''
This script calculates shape and Poisson-Boltzmann (PB) electrostatic potential similarity. It first overlays the
reference and fit molecules using a shape and charge overlay, then rescores the selected top hits (see -numtophits)
with shape and PB potential similarity.

Usage:
python eon_pb_overlay.py -query query.sdf -in database_molecules.oeb.gz -out output.oeb.gz

python eon_pb_overlay.py -query query.sdf -in database_molecules.oeb.gz -out output.oeb.gz -numtophits 2
'''

from openeye import oechem
from openeye import oeet
from openeye import oeshape
from openeye import oequacpac


class EonPBOverlayOptions(oechem.OEOptions):
    """Command-line options for the EON shape + PB potential overlay script.

    Groups the OEEt hitlist options and PB options together with two
    script-level parameters:

    * ``-charges``    -- "mmff94" or "existing" (see the parameter brief)
    * ``-numtophits`` -- number of shape/charge overlay hits to rescore
    """

    def __init__(self):
        oechem.OEOptions.__init__(self, "EonPBOverlayOptions")

        # Nested option groups; keep typed handles to the stored copies.
        addedHitlist = self.AddOption(oeet.OEEonHitlistOptions())
        self._hitOpts = oeet.ToEonHitlistOptions(addedHitlist)
        addedPB = self.AddOption(oeet.OEEonPBOptions())
        self._pbOpts = oeet.ToEonPBOptions(addedPB)

        # -charges: how partial charges are obtained
        chargeChoice = oechem.OEStringParameter("-charges", "mmff94")
        chargeChoice.SetRequired(True)
        chargeChoice.SetVisibility(oechem.OEParamVisibility_Simple)
        chargeChoice.SetBrief("Choices for this parameter are mmff94 (fix pka and set charges to mmff94), or existing (assume input has charges)")
        self._chargesParam = self.AddParameter(chargeChoice)

        # -numtophits: how many overlay hits are rescored with PB potential
        topHitCount = oechem.OEUIntParameter("-numtophits", 5)
        topHitCount.SetIsList(False)
        topHitCount.SetRequired(True)
        topHitCount.SetVisibility(oechem.OEParamVisibility_Simple)
        topHitCount.SetBrief("Number of shape charge overlay hits to rescore with shape and PB potential similarity")
        self._numTopHitsParam = self.AddParameter(topHitCount)

    def CreateCopy(self):
        # NOTE(review): returns this same instance rather than a new copy.
        return self

    def GetCharges(self):
        """Return the -charges value, or its default when the value is empty."""
        return self._chargesParam.GetStringValue() or self._chargesParam.GetStringDefault()

    def GetNumTopHits(self):
        """Return -numtophits as an int, using the default when no value was given."""
        param = self._numTopHitsParam
        text = param.GetStringValue() if param.GetHasValue() else param.GetStringDefault()
        return int(text)

    def GetHitlistOptions(self):
        """Return the nested EON hitlist options."""
        return self._hitOpts

    def GetPBOptions(self):
        """Return the nested EON PB options."""
        return self._pbOpts


def prep_mol(inmol):
    """Prepare a molecule in place for EON scoring with MMFF94 charges.

    Removes color atoms, sets a neutral-pH protonation state, makes hydrogens
    explicit, assigns Bondi van der Waals radii and MMFF94 partial charges,
    then sweeps the molecule.

    :param inmol: molecule to modify in place
    :returns: True if MMFF94 charges were assigned successfully, else False
    """
    # remove color atoms if present
    if oeshape.OEHasColorAtoms(inmol):
        oeshape.OERemoveColorAtoms(inmol)

    # Set the protonation state first: the pH model may change hydrogen
    # counts, and every hydrogen must exist before radii/charges are assigned.
    if not oequacpac.OESetNeutralpHModel(inmol):
        oechem.OEThrow.Warning(f'Unable to set neutral pH model for {inmol.GetTitle()}')

    # add explicit hydrogens and assign radii
    oechem.OEAddExplicitHydrogens(inmol)
    oechem.OEAssignBondiVdWRadii(inmol)

    ok = oequacpac.OEAssignCharges(inmol, oequacpac.OEMMFF94Charges())
    if not ok:
        oechem.OEThrow.Warning(f'Unable to assign MMFF94 charges to {inmol.GetTitle()}')

    inmol.Sweep()
    return ok


def main(argv=None):
    """Overlay database molecules on a query and rescore with PB potential.

    Reads the first molecule from -query, overlays every molecule from -in onto
    it with a shape/charge overlay, and rescores -numtophits hits with shape
    and PB potential similarity. Writes the query followed by the hitlist to
    -out; each hit gets SD data for shape, potential and combo Tanimoto.

    :param argv: argument list including the program name; defaults to
        ``[__name__]``.
    :returns: 0 on success or when help was requested.
    """
    if argv is None:
        argv = [__name__]

    eonOpts = EonPBOverlayOptions()
    opts = oechem.OERefInputAppOptions(
        eonOpts, "EonPBOverlayOptions",
        oechem.OEFileStringType_Mol3D,
        oechem.OEFileStringType_Mol3D,
        oechem.OEFileStringType_Mol3D, "-query")

    if oechem.OEConfigureOpts(opts, argv, False) == oechem.OEOptsConfigureStatus_Help:
        return 0

    eonOpts.UpdateValues(opts)

    qname = opts.GetRefFile()
    iname = opts.GetInFile()
    oname = opts.GetOutFile()
    charges = eonOpts.GetCharges()
    numtophits = eonOpts.GetNumTopHits()

    # Only "mmff94" and "existing" are documented; reject anything else
    # instead of silently treating it as "existing".
    if charges not in ("mmff94", "existing"):
        oechem.OEThrow.Fatal(f'Invalid -charges value "{charges}"; expected "mmff94" or "existing"')
    assignCharges = charges == "mmff94"

    qfs = oechem.oemolistream()
    if not qfs.open(qname):
        oechem.OEThrow.Fatal(f'Unable to open {qname} for reading')

    ifs = oechem.oemolistream()
    if not ifs.open(iname):
        oechem.OEThrow.Fatal(f'Unable to open {iname} for reading')

    ofs = oechem.oemolostream()
    if not ofs.open(oname):
        oechem.OEThrow.Fatal(f'Unable to open {oname} for writing')

    # set up the Shape/Charge Overlap Func used for the initial overlay;
    # the function objects are kept as locals for the lifetime of the run
    shapeFunc = oeshape.OEGridShapeFunc()
    chargeFunc = oeet.OEGridChargeFunc()
    overlapFunc = oeet.OEEonOverlapFunc(shapeFunc, chargeFunc)
    overlayOpts = oeshape.OEOverlayOptions()
    overlayOpts.SetOverlapFunc(overlapFunc)

    eon_pb_overlay_obj = oeet.OEEonPBOverlay(overlayOpts, eonOpts.GetPBOptions())

    # set up query (only the first molecule in the query file is used)
    refmol = oechem.OEMol()
    if not oechem.OEReadMolecule(qfs, refmol):
        oechem.OEThrow.Fatal(f'Unable to read a query molecule from {qname}')
    qfs.close()

    if assignCharges:
        prep_mol(refmol)
    eon_pb_overlay_obj.SetupRef(refmol)

    # write query to output file
    oechem.OEWriteMolecule(ofs, refmol)
    hitlist = oeet.OEEonHitlistBuilder(eonOpts.GetHitlistOptions())

    for mol in ifs.GetOEMols():
        if assignCharges:
            prep_mol(mol)
        score = oeet.OEEonOverlayScore()
        eon_pb_overlay_obj.BestOverlay(score, mol, numtophits)
        hitlist.AddScore(score, mol)
    ifs.close()

    hitlist.Build()

    for hit in hitlist.GetHits():
        hitMol = hit.GetMol()
        oechem.OESetSDData(hitMol, 'Shape Tanimoto', f'{hit.GetTanimoto():.4f}')
        oechem.OESetSDData(hitMol, 'Potential Tanimoto', f'{hit.GetETTanimoto():.4f}')
        oechem.OESetSDData(hitMol, 'Shape + Potential Tanimoto Combo', f'{hit.GetTanimotoCombo():.4f}')
        oechem.OEWriteMolecule(ofs, hitMol)

    # flush and close the (possibly compressed) output explicitly
    ofs.close()
    return 0

if __name__ == '__main__':
    import sys

    # Propagate main()'s return value as the process exit status.
    exitStatus = main(sys.argv)
    sys.exit(exitStatus)