#!/usr/bin/env python3
# (C) 2017 OpenEye Scientific Software Inc. All rights reserved.
#
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of OpenEye products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. OpenEye claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable OpenEye offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall OpenEye be
# liable for any damages or liability in connection with the Sample Code
# or its use.

#############################################################################
# Depicts the B-factor of a ligand and its environment
#############################################################################

import sys
from openeye import oechem
from openeye import oedepict
from openeye import oegrapheme


def main(argv=None):
    """Depict the B-factors of a ligand and of its protein environment.

    Reads a protein-ligand complex and splits it into ligand / protein /
    water / other components. It then prints the average B-factor of the
    whole complex and the min/max B-factor of the ligand and its environment.
    Finally it writes an image of the ligand: atoms are colored by their
    residue B-factor, and surface arcs by the average B-factor of nearby
    protein atoms. A color-gradient legend is drawn below.

    :param argv: command-line arguments, program name first; defaults to
                 ``[__name__]`` when omitted.
    :return: 0 on success, 1 if command-line parsing fails. Other errors are
             reported through ``oechem.OEThrow.Fatal``.
    """
    # Build the default per call instead of sharing a mutable default list.
    if argv is None:
        argv = [__name__]

    itf = oechem.OEInterface()
    oechem.OEConfigure(itf, InterfaceData)
    oedepict.OEConfigureImageWidth(itf, 600.0)
    oedepict.OEConfigureImageHeight(itf, 400.0)
    oedepict.OEConfigure2DMolDisplayOptions(itf, oedepict.OE2DMolDisplaySetup_AromaticStyle)
    oechem.OEConfigureSplitMolComplexOptions(itf, oechem.OESplitMolComplexSetup_LigName)

    if not oechem.OEParseCommandLine(itf, argv):
        return 1

    iname = itf.GetString("-complex")
    oname = itf.GetString("-out")

    maxdist = itf.GetFloat("-maxdist")

    ifs = oechem.oemolistream()
    if not ifs.open(iname):
        oechem.OEThrow.Fatal("Cannot open input file!")

    ext = oechem.OEGetFileExtension(oname)
    issvg = ext == "svg"

    if not oedepict.OEIsRegisteredImageFile(ext):
        oechem.OEThrow.Fatal("Unknown image type!")

    ofs = oechem.oeofstream()
    if not ofs.open(oname):
        oechem.OEThrow.Fatal("Cannot open output file!")

    complexmol = oechem.OEGraphMol()
    if not oechem.OEReadMolecule(ifs, complexmol):
        oechem.OEThrow.Fatal("Unable to read molecule from %s" % iname)

    if not oechem.OEHasResidues(complexmol):
        oechem.OEPerceiveResidues(complexmol, oechem.OEPreserveResInfo_All)

    # separate ligand and protein

    sopts = oechem.OESplitMolComplexOptions()
    oechem.OESetupSplitMolComplexOptions(sopts, itf)

    ligand = oechem.OEGraphMol()
    protein = oechem.OEGraphMol()
    water = oechem.OEGraphMol()
    other = oechem.OEGraphMol()

    oechem.OESplitMolComplex(ligand, protein, water, other, complexmol, sopts)

    if ligand.NumAtoms() == 0:
        oechem.OEThrow.Fatal("Cannot separate complex!")

    # calculate average BFactor of the whole complex

    avgbfactor = get_average_bfactor(complexmol)

    # calculate minimum and maximum BFactor of the ligand and its environment

    minbfactor, maxbfactor = get_min_and_max_bfactor(ligand, protein, maxdist)

    # attach to each ligand atom the average BFactor of the nearby protein atoms

    stag = "avg residue BFfactor"
    itag = oechem.OEGetTag(stag)
    set_average_bfactor_of_nearby_protein_atoms(ligand, protein, itag, maxdist)

    print("Avg B-factor of the complex = %+.2f" % avgbfactor)
    print("Min B-factor of the ligand and its environment (%.1fA) = %+.2f" % (maxdist, minbfactor))
    print("Max B-factor of the ligand and its environment (%.1fA) = %+.2f" % (maxdist, maxbfactor))

    iwidth, iheight = oedepict.OEGetImageWidth(itf), oedepict.OEGetImageHeight(itf)
    image = oedepict.OEImage(iwidth, iheight)

    # Top 85% of the image holds the molecule, bottom 15% the gradient legend.
    mframe = oedepict.OEImageFrame(image, iwidth, iheight * 0.85,
                                   oedepict.OE2DPoint(0.0, 0.0))
    lframe = oedepict.OEImageFrame(image, iwidth, iheight * 0.15,
                                   oedepict.OE2DPoint(0.0, iheight * 0.85))

    colorg = get_bfactor_color_gradient()

    opts = oedepict.OE2DMolDisplayOptions(mframe.GetWidth(), mframe.GetHeight(),
                                          oedepict.OEScale_AutoScale)
    oedepict.OESetup2DMolDisplayOptions(opts, itf)
    opts.SetAtomColorStyle(oedepict.OEAtomColorStyle_WhiteMonochrome)
    opts.SetTitleLocation(oedepict.OETitleLocation_Hidden)

    depict_bfactor(mframe, ligand, opts, colorg, itag, issvg)
    depict_color_gradient(lframe, colorg, minbfactor, maxbfactor, avgbfactor)

    if issvg:
        iconscale = 0.5
        oedepict.OEAddInteractiveIcon(image, oedepict.OEIconLocation_TopRight, iconscale)
    oedepict.OEDrawCurvedBorder(image, oedepict.OELightGreyPen, 10.0)

    # Write through the stream opened (and validated) above, rather than
    # reopening the file by name while that stream is still open.
    if not oedepict.OEWriteImage(ofs, ext, image):
        oechem.OEThrow.Fatal("Cannot write image!")
    ofs.close()

    return 0


#############################################################################
#
#############################################################################


def depict_bfactor(image, ligand, opts, colorg, itag, issvg):
    """
    Render the ligand into ``image`` together with its B-factor information.

    Each atom gets a filled background circle colored by the B-factor of its
    residue record. The molecule surface arcs are colored from the per-atom
    value stored as double data under ``itag``. For SVG output, atoms with
    residue data also get a "bfactor=..." hover text.

    :type image: oedepict.OEImageBase
    :type ligand: oechem.OEMolBase
    :type opts: oedepict.OE2DMolDisplayOptions
    :type colorg: oechem.OELinearColorGradient
    :type itag: int
    :type issvg: bool
    """
    # Derive a 2D layout from the 3D structure. Existing coordinates are not
    # cleared, hydrogens are not suppressed, and the layout is horizontal.
    oegrapheme.OEPrepareDepictionFrom3D(ligand)
    prep_opts = oedepict.OEPrepareDepictionOptions(False, False)  # clearcoords, suppressH
    prep_opts.SetDepictOrientation(oedepict.OEDepictOrientation_Horizontal)
    oedepict.OEPrepareDepiction(ligand, prep_opts)

    # The same B-factor functor draws the surface arc of every atom.
    surface_arc = BFactorArcFxn(colorg, itag)
    for lig_atom in ligand.GetAtoms():
        oegrapheme.OESetSurfaceArcFxn(ligand, lig_atom, surface_arc)
    # Use the scale recommended for depicting the molecule with its surface.
    opts.SetScale(oegrapheme.OEGetMoleculeSurfaceScale(ligand, opts))

    disp = oedepict.OE2DMolDisplay(ligand, opts)

    if issvg:
        hover_font = oedepict.OEFont(oedepict.OEFontFamily_Default, oedepict.OEFontStyle_Default,
                                     14, oedepict.OEAlignment_Center, oechem.OEBlack)
        for atom_disp in disp.GetAtomDisplays():
            disp_atom = atom_disp.GetAtom()
            if oechem.OEHasResidue(disp_atom):
                bfactor = oechem.OEAtomGetResidue(disp_atom).GetBFactor()
                oedepict.OEDrawSVGHoverText(disp, atom_disp, "bfactor=%.1f" % bfactor, hover_font)

    atom_glyph = ColorLigandAtomByBFactor(colorg)
    oegrapheme.OEAddGlyph(disp, atom_glyph, oechem.OEIsTrueAtom())

    oegrapheme.OEDraw2DSurface(disp)
    oedepict.OERenderMolecule(image, disp)


def depict_color_gradient(image, colorg, minbfactor, maxbfactor, avgbfactor):
    """Draw ``colorg`` as a legend into ``image``.

    ``avgbfactor`` is added as a marked value, and the box range is set to
    [``minbfactor``, ``maxbfactor``]. The color-stop precision is 1.
    """
    legend_opts = oegrapheme.OEColorGradientDisplayOptions()
    legend_opts.SetColorStopPrecision(1)
    legend_opts.AddMarkedValue(avgbfactor)
    legend_opts.SetBoxRange(minbfactor, maxbfactor)

    oegrapheme.OEDrawColorGradient(image, colorg, legend_opts)


def get_average_bfactor(mol):
    """Return the mean B-factor over the atoms of ``mol`` that carry residue data.

    Atoms without residue information are skipped. When no atom carries
    residue data the function returns 0.0 rather than dividing by zero.

    :param mol: molecule whose atoms are inspected (via ``mol.GetAtoms()``)
    :return: average B-factor as a float
    """
    bfactors = [
        oechem.OEAtomGetResidue(atom).GetBFactor()
        for atom in mol.GetAtoms()
        if oechem.OEHasResidue(atom)
    ]
    if not bfactors:
        return 0.0
    return sum(bfactors) / len(bfactors)


class NotHydrogenOrWater(oechem.OEUnaryAtomPred):
    """Atom predicate selecting atoms whose B-factor should be considered.

    An atom qualifies only if it is not a hydrogen, carries residue data,
    and is not matched by ``oechem.OEIsWater``.
    """

    def __call__(self, atom):
        if atom.GetAtomicNum() == oechem.OEElemNo_H:
            return False
        if not oechem.OEHasResidue(atom):
            return False

        waterpred = oechem.OEIsWater()
        return not waterpred(atom)

    def CreateCopy(self):
        # Python functors handed to OpenEye C++ APIs must return a copy whose
        # ownership is transferred to C++ (__disown__). This matches the other
        # functor classes in this file.
        return NotHydrogenOrWater().__disown__()


def get_min_and_max_bfactor(ligand, protein, maxdistance):
    """Return ``(min, max)`` B-factor of the ligand and its protein environment.

    Two sets of atoms contribute:
    - the heavy ligand atoms that carry residue data, and
    - every protein atom reported by ``oechem.OENearestNbrs`` (cutoff
      ``maxdistance``) around a heavy ligand atom that passes
      ``NotHydrogenOrWater``.

    If no atom contributes, ``(inf, -inf)`` is returned.
    """
    lowest, highest = float("inf"), float("-inf")

    def _include(atom):
        nonlocal lowest, highest
        bfactor = oechem.OEAtomGetResidue(atom).GetBFactor()
        lowest = min(lowest, bfactor)
        highest = max(highest, bfactor)

    # the ligand's own heavy atoms
    for lig_atom in ligand.GetAtoms(oechem.OEIsHeavy()):
        if oechem.OEHasResidue(lig_atom):
            _include(lig_atom)

    # protein atoms found near the ligand's heavy atoms
    accept = NotHydrogenOrWater()
    nbr_search = oechem.OENearestNbrs(protein, maxdistance)
    for lig_atom in ligand.GetAtoms(oechem.OEIsHeavy()):
        for nbr in nbr_search.GetNbrs(lig_atom):
            prot_atom = nbr.GetBgn()
            if accept(prot_atom):
                _include(prot_atom)

    return lowest, highest


def set_average_bfactor_of_nearby_protein_atoms(ligand, protein, itag, maxdistance):
    """Store on each heavy ligand atom the mean B-factor of its protein neighbours.

    Neighbours are the protein atoms reported by ``oechem.OENearestNbrs``
    (cutoff ``maxdistance``) that pass ``NotHydrogenOrWater``. The mean is
    stored as double data under ``itag``. It is 0.0 when an atom has no
    such neighbours.
    """
    accept = NotHydrogenOrWater()
    nbr_search = oechem.OENearestNbrs(protein, maxdistance)

    for lig_atom in ligand.GetAtoms(oechem.OEIsHeavy()):
        nearby = (nbr.GetBgn() for nbr in nbr_search.GetNbrs(lig_atom))
        bfactors = [oechem.OEAtomGetResidue(prot_atom).GetBFactor()
                    for prot_atom in nearby if accept(prot_atom)]
        mean = sum(bfactors, 0.0) / len(bfactors) if bfactors else 0.0
        lig_atom.SetDoubleData(itag, mean)


def get_bfactor_color_gradient():
    """Return the linear color gradient that maps B-factor values to colors.

    Stops: 0 dark blue, 10 light blue, 25 yellow tint, 50 red, 100 dark rose.
    """
    stops = (
        (0.0, oechem.OEDarkBlue),
        (10.0, oechem.OELightBlue),
        (25.0, oechem.OEYellowTint),
        (50.0, oechem.OERed),
        (100.0, oechem.OEDarkRose),
    )

    gradient = oechem.OELinearColorGradient()
    for value, color in stops:
        gradient.AddStop(oechem.OEColorStop(value, color))
    return gradient


class BFactorArcFxn(oegrapheme.OESurfaceArcFxnBase):
    """Surface-arc functor that colors each atom's arc by a per-atom value.

    The value is read from the atom's double data under ``itag`` and mapped
    to a color through ``colorg``. Atoms whose value is exactly 0.0 get no
    arc drawn, but the arc is still reported as handled.
    """

    def __init__(self, colorg, itag):
        oegrapheme.OESurfaceArcFxnBase.__init__(self)
        self.colorg = colorg
        self.itag = itag

    def __call__(self, image, arc):
        atom_disp = arc.GetAtomDisplay()
        if atom_disp is None or not atom_disp.IsVisible():
            return False

        arc_atom = atom_disp.GetAtom()
        if arc_atom is None:
            return False

        value = arc_atom.GetDoubleData(self.itag)
        if value == 0.0:
            # set_average_bfactor_of_nearby_protein_atoms stores 0.0 when no
            # protein atoms are nearby; nothing is drawn for such atoms.
            return True

        arc_color = self.colorg.GetColorAt(value)
        arc_pen = oedepict.OEPen(arc_color, arc_color, oedepict.OEFill_Off, 5.0)

        oegrapheme.OEDrawDefaultSurfaceArc(image, arc.GetCenter(),
                                           arc.GetBgnAngle(), arc.GetEndAngle(),
                                           arc.GetRadius(), arc_pen)
        return True

    def CreateCopy(self):
        # __disown__ hands ownership of the copy to the C++ side.
        return BFactorArcFxn(self.colorg, self.itag).__disown__()


class ColorLigandAtomByBFactor(oegrapheme.OEAtomGlyphBase):
    """Atom glyph that draws a filled circle behind each atom, colored by B-factor.

    The B-factor is read from the atom's residue data and mapped to a color
    through ``colorg``. Atoms that are not visible or have no residue data
    are skipped.
    """

    def __init__(self, colorg):
        oegrapheme.OEAtomGlyphBase.__init__(self)
        self.colorg = colorg

    def RenderGlyph(self, disp, atom):
        atom_disp = disp.GetAtomDisplay(atom)
        visible = atom_disp is not None and atom_disp.IsVisible()
        if not (visible and oechem.OEHasResidue(atom)):
            return False

        fill_color = self.colorg.GetColorAt(oechem.OEAtomGetResidue(atom).GetBFactor())
        fill_pen = oedepict.OEPen(fill_color, fill_color, oedepict.OEFill_On, 1.0)
        circle_radius = disp.GetScale() / 3.0

        # drawn on the layer below the molecule so the atom label stays on top
        below = disp.GetLayer(oedepict.OELayerPosition_Below)
        oegrapheme.OEDrawCircle(below, oegrapheme.OECircleStyle_Default,
                                atom_disp.GetCoords(), circle_radius, fill_pen)
        return True

    def CreateCopy(self):
        # __disown__ hands ownership of the copy to the C++ side.
        return ColorLigandAtomByBFactor(self.colorg).__disown__()


#############################################################################
# INTERFACE
#############################################################################

# OEInterface parameter definition consumed by oechem.OEConfigure in main().
# Every !CATEGORY and !PARAMETER section is closed by its own !END.
InterfaceData = '''
!BRIEF [-complex] <input> [-out] <output image>

!CATEGORY "input/output options :" 1

  !PARAMETER -complex 1
    !ALIAS -c
    !TYPE string
    !KEYLESS 1
    !REQUIRED true
    !VISIBILITY simple
    !BRIEF Input filename of the protein-ligand complex
  !END

  !PARAMETER -out 2
    !ALIAS -o
    !TYPE string
    !REQUIRED true
    !KEYLESS 2
    !VISIBILITY simple
    !BRIEF Output filename of the generated image
  !END

!END

!CATEGORY "B-factor options :" 2

  !PARAMETER -max_distance
    !ALIAS -maxdist
    !TYPE float
    !REQUIRED false
    !DEFAULT 4.0
    !LEGAL_RANGE 2.0 8.0
    !VISIBILITY simple
    !BRIEF Maximum distance of receptor atoms
    !DETAIL
        When visualizing the B-factor of the environment of the ligand only
        protein atoms that are closer to any ligand atom than this limit are considered
  !END

!END
'''

if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main(sys.argv))
