#!/usr/bin/env python3
# (C) 2017 OpenEye Scientific Software Inc. All rights reserved.
#
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of OpenEye products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. OpenEye claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable OpenEye offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall OpenEye be
# liable for any damages or liability in connection with the Sample Code
# or its use.

#############################################################################
# Depicts molecules with fragment highlights
#############################################################################

import sys
import math
from openeye import oechem
from openeye import oedepict
from openeye import oemedchem


def main(argv=None):
    """Parse the command line, fragment each input molecule and write a report.

    Each input molecule gets its own report page: the page header shows the
    whole molecule with its fragments highlighted, and the page body shows a
    grid of the individual fragments (see DepictMoleculesWithFragments).

    :param argv: command-line argument list, program name first.  Defaults to
        ``[__name__]`` (no options), the same value as the previous default.
    :returns: 0 on success, 1 if the command line could not be parsed.
        File problems are reported through ``oechem.OEThrow.Fatal``.
    """
    # Use a None sentinel instead of a shared mutable default list.
    if argv is None:
        argv = [__name__]

    itf = oechem.OEInterface(InterfaceData)
    oedepict.OEConfigure2DMolDisplayOptions(itf)
    oedepict.OEConfigureReportOptions(itf)

    if not oechem.OEParseCommandLine(itf, argv):
        return 1

    iname = itf.GetString("-in")
    oname = itf.GetString("-out")

    pagebypage = itf.GetBool("-pagebypage")

    # check input/output files

    ifs = oechem.oemolistream()
    if not ifs.open(iname):
        oechem.OEThrow.Fatal("Cannot open input file!")

    ext = oechem.OEGetFileExtension(oname)
    # Reject an unwritable output format before doing any depiction work.
    if not oedepict.OEIsRegisteredImageFile(ext):
        oechem.OEThrow.Fatal("Unknown image type!")
    if not pagebypage and not oedepict.OEIsRegisteredMultiPageImageFile(ext):
        # The output format cannot hold several pages in one file, so each
        # report page is written separately instead.
        oechem.OEThrow.Warning("Report will be generated into separate pages!")
        pagebypage = True

    # initialize fragmentation function

    fragfunc = GetFragmentationFunction(itf)

    # initialize multi-page report

    ropts = oedepict.OEReportOptions()
    oedepict.OESetupReportOptions(ropts, itf)
    ropts.SetFooterHeight(25.0)
    # The page header (one third of the page) holds the highlighted molecule.
    ropts.SetHeaderHeight(ropts.GetPageHeight() / 3.0)
    report = oedepict.OEReport(ropts)

    # setup depiction options

    moldispopts = oedepict.OE2DMolDisplayOptions()
    oedepict.OESetup2DMolDisplayOptions(moldispopts, itf)
    cellwidth, cellheight = report.GetHeaderWidth(), report.GetHeaderHeight()
    moldispopts.SetDimensions(cellwidth, cellheight, oedepict.OEScale_AutoScale)
    moldispopts.SetAtomColorStyle(oedepict.OEAtomColorStyle_WhiteMonochrome)
    moldispopts.SetAtomLabelFontScale(1.3)

    fragdispopts = oedepict.OE2DMolDisplayOptions()
    oedepict.OESetup2DMolDisplayOptions(fragdispopts, itf)
    fragdispopts.SetTitleLocation(oedepict.OETitleLocation_Hidden)
    fragdispopts.SetAtomLabelFontScale(1.3)

    # read molecules; each one is copied because the stream iterator may
    # reuse a single molecule object between iterations

    mollist = [oechem.OEGraphMol(mol) for mol in ifs.GetOEGraphMols()]
    ifs.close()

    # depict molecules with fragments

    DepictMoleculesWithFragments(report, mollist, fragfunc,
                                 moldispopts, fragdispopts)

    if pagebypage:
        oedepict.OEWriteReportPageByPage(oname, report)
    else:
        oedepict.OEWriteReport(oname, report)

    return 0


def DepictMoleculesWithFragments(report, mollist, fragfunc,
                                 moldispopts, fragdispopts):
    """Add one report page per molecule showing the molecule and its fragments.

    For every molecule a new report body is created.  Its page header shows
    the whole molecule with each fragment highlighted (ball-and-stick) in a
    color taken from a yellow-to-orange gradient.  The body holds a grid with
    one cell per fragment.  All fragments on a page are drawn at one common
    scale, and each cell is framed in its fragment's highlight color.

    :param report: ``oedepict.OEReport`` the pages are added to.
    :param mollist: molecules to depict.  ``OEPrepareDepiction`` is called on
        each of them, so they are modified in place.
    :param fragfunc: callable that takes a molecule and returns an iterable of
        atom/bond sets, one per fragment (e.g. an oemedchem fragment function).
    :param moldispopts: display options for the whole molecule.  Its
        dimensions are overwritten to fit the page header.
    :param fragdispopts: display options for the fragments.  Its dimensions,
        title location and scale are overwritten.
    """
    for mol in mollist:

        body = report.NewBody()
        oedepict.OEPrepareDepiction(mol)
        header = report.GetHeader(report.NumPages())

        # fragment the molecule and extract every fragment as its own molecule

        fragsets = list(fragfunc(mol))
        fragmols = []
        for fset in fragsets:
            fragment = oechem.OEGraphMol()
            fragpred = oechem.OEIsAtomMember(fset.GetAtoms())
            adjustHCount = True
            oechem.OESubsetMol(fragment, mol, fragpred, adjustHCount)
            # Prepare the fragment now so that the scale computed below is
            # based on the same coordinates that are eventually rendered.
            oedepict.OEPrepareDepiction(fragment)
            fragmols.append(fragment)

        nrfrags = len(fragmols)
        # Keep the two gradient stops apart even with 0 or 1 fragments;
        # otherwise both stops would coincide or the end stop would be -1.
        laststop = max(nrfrags - 1, 1)
        colorg = oechem.OELinearColorGradient(oechem.OEColorStop(0, oechem.OEYellowTint),
                                              oechem.OEColorStop(laststop, oechem.OEDarkOrange))

        # render molecule with fragment highlights

        cellwidth, cellheight = report.GetHeaderWidth(), report.GetHeaderHeight()
        moldispopts.SetDimensions(cellwidth, cellheight, oedepict.OEScale_AutoScale)

        disp = oedepict.OE2DMolDisplay(mol, moldispopts)
        for fidx, fset in enumerate(fragsets):
            color = colorg.GetColorAt(fidx)
            oedepict.OEAddHighlighting(disp, color, oedepict.OEHighlightStyle_BallAndStick, fset)

        oedepict.OERenderMolecule(header, disp)

        # create fragment grid; rows * cols > nrfrags always holds here
        # because cols >= nrfrags // rows + 1, so no fragment is dropped

        rows = max(2, int(math.sqrt(nrfrags + 1)))
        cols = max(2, int(nrfrags / rows) + 1)
        grid = oedepict.OEImageGrid(body, rows, cols)
        grid.SetCellGap(8.0)

        cellwidth, cellheight = grid.GetCellWidth(), grid.GetCellHeight()
        # Resetting to auto-scale also discards the fixed scale set for the
        # previous molecule.
        fragdispopts.SetDimensions(cellwidth, cellheight, oedepict.OEScale_AutoScale)
        fragdispopts.SetTitleLocation(oedepict.OETitleLocation_Hidden)

        # determine a common scale so all fragments are depicted at equal size;
        # it never exceeds 1.25x the scale at which the whole molecule fits a
        # cell, which keeps small fragments from being blown up

        minscale = oedepict.OEGetMoleculeScale(mol, fragdispopts) * 1.25
        for frag in fragmols:
            minscale = min(minscale, oedepict.OEGetMoleculeScale(frag, fragdispopts))
        fragdispopts.SetScale(minscale)

        # render each fragment, framed in its highlight color

        for fidx, (cell, fmol) in enumerate(zip(grid.GetCells(), fragmols)):
            disp = oedepict.OE2DMolDisplay(fmol, fragdispopts)
            oedepict.OERenderMolecule(cell, disp)

            color = colorg.GetColorAt(fidx)
            pen = oedepict.OEPen(oechem.OEWhite, color, oedepict.OEFill_Off, 3.0)
            oedepict.OEDrawBorder(cell, pen)


def GetFragmentationFunction(itf):
    """Return the oemedchem fragmentation routine chosen by ``-fragtype``.

    ``funcgroup`` maps to functional-group fragmentation and ``ring-chain``
    to ring/chain fragmentation.  Every other value, including
    ``ring-linker-sidechain``, falls back to ring/linker/side-chain
    fragmentation.
    """
    fragtype = itf.GetString("-fragtype")
    by_name = {
        "funcgroup": oemedchem.OEGetFuncGroupFragments,
        "ring-chain": oemedchem.OEGetRingChainFragments,
    }
    return by_name.get(fragtype, oemedchem.OEGetRingLinkerSideChainFragments)


#############################################################################
# INTERFACE
#############################################################################

# Command-line interface definition parsed by oechem.OEInterface in main().
# Declares the -in/-out filenames (also accepted as the first two keyless
# arguments), the -fragtype option selecting the oemedchem fragmentation
# routine (see GetFragmentationFunction; also the third keyless argument),
# and the -pagebypage flag.  main() additionally registers the standard
# 2D molecule display and report options on the same interface.
InterfaceData = '''
!CATEGORY "input/output options"
    !PARAMETER -in
      !ALIAS -i
      !TYPE string
      !REQUIRED true
      !KEYLESS 1
      !VISIBILITY simple
      !BRIEF Input filename
    !END

    !PARAMETER -out
      !ALIAS -o
      !TYPE string
      !REQUIRED true
      !KEYLESS 2
      !VISIBILITY simple
      !BRIEF Output filename
    !END
!END

!CATEGORY "fragmentation options"
    !PARAMETER -fragtype
      !ALIAS -ftype
      !TYPE string
      !REQUIRED false
      !KEYLESS 3
      !DEFAULT funcgroup
      !LEGAL_VALUE funcgroup
      !LEGAL_VALUE ring-chain
      !LEGAL_VALUE ring-linker-sidechain
      !VISIBILITY simple
      !BRIEF Fragmentation type
    !END
!END

!CATEGORY "report options"

    !PARAMETER -pagebypage
      !ALIAS -p
      !TYPE bool
      !REQUIRED false
      !DEFAULT false
      !VISIBILITY simple
      !BRIEF Write individual numbered separate pages
    !END

!END
'''

# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv))
