#!/usr/bin/env python3
# (C) 2017 OpenEye Scientific Software Inc. All rights reserved.
#
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of OpenEye products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. OpenEye claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable OpenEye offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall OpenEye be
# liable for any damages or liability in connection with the Sample Code
# or its use.

#############################################################################
#
#############################################################################

import sys
import math

from openeye import oechem
from openeye import oedepict
from openeye import oegrapheme
from openeye import oemolprop


def main(argv=None):
    """Depict the first molecule of ``-in`` with a property pie chart and write it to ``-out``.

    Args:
        argv: command-line arguments, program name first. Defaults to
            ``[__name__]`` (a fresh list per call instead of a shared mutable
            default).

    Returns:
        0 on success, 1 if command-line parsing fails. Other failures
        (unknown image type, unreadable input) are reported through
        ``oechem.OEThrow.Fatal``.
    """
    if argv is None:
        argv = [__name__]

    itf = oechem.OEInterface()
    oechem.OEConfigure(itf, InterfaceData)
    oedepict.OEConfigureImageWidth(itf, 600.0)
    oedepict.OEConfigureImageHeight(itf, 400.0)
    displayflags = (oedepict.OE2DMolDisplaySetup_AromaticStyle |
                    oedepict.OE2DMolDisplaySetup_TitleLocation)
    oedepict.OEConfigure2DMolDisplayOptions(itf, displayflags)

    if not oechem.OEParseCommandLine(itf, argv):
        return 1

    iname = itf.GetString('-in')
    oname = itf.GetString('-out')
    colorgradient = itf.GetBool('-colorgradient')

    ext = oechem.OEGetFileExtension(oname)
    if not oedepict.OEIsRegisteredImageFile(ext):
        oechem.OEThrow.Fatal("Unknown image type!")

    ifs = oechem.oemolistream()
    if not ifs.open(iname):
        oechem.OEThrow.Fatal("Cannot open input file!")

    # Only the first molecule is depicted, so release the input file right away.
    mol = oechem.OEGraphMol()
    readok = oechem.OEReadMolecule(ifs, mol)
    ifs.close()
    if not readok:
        oechem.OEThrow.Fatal("Cannot read input file!")

    width, height = oedepict.OEGetImageWidth(itf), oedepict.OEGetImageHeight(itf)
    opts = oedepict.OE2DMolDisplayOptions(width, height, oedepict.OEScale_AutoScale)
    oedepict.OESetup2DMolDisplayOptions(opts, itf)
    # render_properties sizes the pie from the right margin and the gradient
    # strip from the bottom margin, so both are reserved here.
    opts.SetMargin(oedepict.OEMargin_Right, 25.0)
    opts.SetMargin(oedepict.OEMargin_Bottom, 20.0)
    opts.SetBondWidthScaling(True)

    clearcoords, suppressH = True, True
    popts = oedepict.OEPrepareDepictionOptions(clearcoords, suppressH)
    popts.SetDepictOrientation(oedepict.OEDepictOrientation_Horizontal)
    oedepict.OEPrepareDepiction(mol, popts)

    propdisps = get_property_displays()
    set_properties(mol, propdisps)

    image = oedepict.OEImage(width, height)
    render_properties(image, mol, opts, propdisps, colorgradient)

    # The hover groups are SVG features, so the interactive icon is only added for SVG.
    if ext == 'svg':
        iconscale = 0.5
        oedepict.OEAddInteractiveIcon(image, oedepict.OEIconLocation_TopRight, iconscale)
    oedepict.OEDrawCurvedBorder(image, oedepict.OELightGreyPen, 10.0)

    oedepict.OEWriteImage(oname, image)

    return 0


class OEPropertyDisplay():
    """One molecular property as shown in a pie slice: identity, value and color.

    Args:
        id: short identifier; drawn inside the pie slice and used as the
            SVG group id by render_properties.
        label: human-readable property name used in the hover text.
        filtername: column name of the property in the OEMolProp filter table.
        colorg: color gradient mapping numeric values to colors.
        valuedict: optional mapping from textual values (e.g. solubility
            classes) to numbers, applied in SetValue before float conversion.
    """

    def __init__(self, id, label, filtername, colorg, valuedict=None):
        self._id = id
        self._label = label
        self._filtername = filtername
        # A new gradient is built from colorg (presumably a copy), so later
        # edits to the caller's gradient do not affect this display.
        self._colorg = oechem.OELinearColorGradient(colorg)
        self._origvalue = None  # value exactly as passed to SetValue
        self._value = None      # float form used for coloring
        self._valuedict = valuedict

    def __str__(self):
        return '{} = {}'.format(self._label, self._value)

    def GetId(self):
        """Return the short identifier."""
        return self._id

    def GetLabel(self):
        """Return the human-readable label."""
        return self._label

    def GetFilterName(self):
        """Return the filter-table column name of this property."""
        return self._filtername

    def SetValue(self, value):
        """Store *value* and its numeric (float) form.

        The value is kept as given for display (see GetOrigValue). String
        values have surrounding whitespace removed before being mapped through
        the value dictionary (if any) and converted to float.

        Raises:
            KeyError: a value dictionary is set and does not contain the value.
            ValueError: the (mapped) value is not convertible to float.
        """
        self._origvalue = value
        if isinstance(value, str):
            value = value.strip()
        if self._valuedict is not None:
            value = self._valuedict[value]
        self._value = float(value)

    def GetOrigValue(self):
        """Return the value exactly as passed to SetValue, or None if unset."""
        return self._origvalue

    def GetValue(self):
        """Return the numeric value, or None if SetValue was never called."""
        return self._value

    def GetColorGradient(self):
        """Return this display's color gradient."""
        return self._colorg

    def GetColor(self):
        """Return the gradient color at the value, or OEWhite when no value is set."""
        if self.GetValue() is None:
            return oechem.OEWhite
        return self._colorg.GetColorAt(self.GetValue())

    def GetPiePen(self):
        """Return a filled pen in the property color (line width 1.0) for the pie slice."""
        color = self.GetColor()
        return oedepict.OEPen(color, color, oedepict.OEFill_On, 1.0)

    def GetLabelBorderPen(self):
        """Return a filled pen built from white and the property color (width 3.0) for labels."""
        color = self.GetColor()
        return oedepict.OEPen(oechem.OEWhite, color, oedepict.OEFill_On, 3.0)


def render_properties(image, mol, opts, properties, colorgradient):
    """Render *mol* into *image* and overlay a property pie chart.

    The pie sits in a square frame near the bottom-right corner whose side is
    the right margin of *opts* scaled as a percentage of the image width.
    Each property gets an equal slice. Every slice is paired with a hover
    group (via OEAddSVGHover) that holds either a color-gradient legend drawn
    in a strip along the bottom (*colorgradient* true) or a text label.

    Args:
        image: oedepict.OEImage to draw into.
        mol: molecule already prepared for depiction.
        opts: oedepict.OE2DMolDisplayOptions for the molecule and layout.
        properties: sequence of OEPropertyDisplay objects, one per slice.
            If empty, only the molecule is drawn.
        colorgradient: show values on color gradients instead of labels.
    """
    disp = oedepict.OE2DMolDisplay(mol, opts)
    oedepict.OERenderMolecule(image, disp)

    if not properties:
        # Nothing to chart; this also avoids dividing 360 degrees by zero.
        return

    imagew, imageh = image.GetWidth(), image.GetHeight()

    pframesize = opts.GetMargin(oedepict.OEMargin_Right) * imagew / 100.0
    pframe = oedepict.OEImageFrame(image, pframesize, pframesize,
                                   oedepict.OE2DPoint(imagew - pframesize * 1.1,
                                                      imageh - pframesize * 1.1))

    # Strip along the bottom-left, next to the pie, for the optional gradient legend.
    cframew = imagew - pframe.GetWidth() * 1.25
    cframeh = opts.GetMargin(oedepict.OEMargin_Bottom) * 0.75 * imageh / 100.0
    cframe = oedepict.OEImageFrame(image, cframew, cframeh,
                                   oedepict.OE2DPoint(10.0, imageh - cframeh - 10.0))

    center = oedepict.OEGetCenter(pframe)
    radius = pframesize / 2.0
    incangle = 360.0 / len(properties)

    group_prefix = 'property_hover_'
    bgnangle = 0.0
    for prop in properties:
        endangle = bgnangle + incangle

        area_group = image.NewSVGGroup(prop.GetId())
        hover_group = image.NewSVGGroup(group_prefix + prop.GetId())
        oedepict.OEAddSVGHover(area_group, hover_group)

        image.PushGroup(area_group)
        draw_property_pie(pframe, prop, center, bgnangle, endangle, radius)
        image.PopGroup(area_group)

        image.PushGroup(hover_group)
        labeltext = '{} = {}'.format(prop.GetLabel(), prop.GetOrigValue())
        if colorgradient:
            draw_property_color_gradient(cframe, prop, labeltext)
        else:
            draw_property_label(image, prop, labeltext)
        image.PopGroup(hover_group)

        bgnangle = endangle


def draw_property_color_gradient(image, prop, labeltext):
    """Draw the color gradient of *prop* into *image*, marking its value.

    Args:
        image: image or image frame to draw into.
        prop: OEPropertyDisplay whose gradient is drawn.
        labeltext: text attached to the marker at the property's value.

    A property with no value (SetValue never called, e.g. its column was
    missing from the filter table) has None as its value. That is not a valid
    marker position, so the gradient is drawn without a marker.
    """
    opts = oegrapheme.OEColorGradientDisplayOptions()
    opts.SetColorStopPrecision(1)
    opts.SetColorStopLabelFontScale(0.5)
    value = prop.GetValue()
    if value is not None:
        opts.AddLabel(oegrapheme.OEColorGradientLabel(value, labeltext))
    oegrapheme.OEDrawColorGradient(image, prop.GetColorGradient(), opts)


def draw_property_label(image, prop, labeltext):
    """Draw *labeltext* centered 25 units above the bottom edge of *image*.

    The label uses an 18-point bold black font and is boxed with the
    property's label border pen, which is tinted with the property color.
    """
    font = oedepict.OEFont(oedepict.OEFontFamily_Default, oedepict.OEFontStyle_Bold, 18,
                           oedepict.OEAlignment_Center, oechem.OEBlack)
    # The font is given to the constructor, so no separate SetFont call is needed.
    label = oedepict.OEHighlightLabel(labeltext, font)
    label.SetBoundingBoxPen(prop.GetLabelBorderPen())
    labelpos = oedepict.OE2DPoint(image.GetWidth() / 2.0, image.GetHeight() - 25.0)
    oedepict.OEAddLabel(image, labelpos, label)


def draw_property_pie(image, prop, center, bgnangle, endangle, radius):
    """Draw one pie slice for *prop* between *bgnangle* and *endangle*.

    The slice is colored with the property's pie pen and has a light-grey
    shadow offset by (3, 3). The property id is written in bold black at a
    point toward the slice's mid-angle.
    """
    # The slice origin is shifted slightly away from the center toward the
    # mid-angle, presumably so that slices appear pulled apart.
    slice_origin = get_pie_label_position(center, bgnangle, endangle, radius / 10.0)
    shadow_origin = slice_origin + oedepict.OE2DPoint(3.0, 3.0)

    image.DrawPie(shadow_origin, bgnangle, endangle, radius, oedepict.OELightGreyBoxPen)
    image.DrawPie(slice_origin, bgnangle, endangle, radius, prop.GetPiePen())

    id_text = prop.GetId()
    id_font = oedepict.OEFont(oedepict.OEFontFamily_Default, oedepict.OEFontStyle_Bold,
                              int(radius / 5), oedepict.OEAlignment_Center, oechem.OEBlack)
    id_label = oedepict.OEHighlightLabel(id_text, id_font)
    id_label.SetBoundingBoxPen(oedepict.OETransparentPen)

    oedepict.OEAddLabel(image, get_pie_label_position(center, bgnangle, endangle, radius),
                        id_label)


def get_pie_label_position(center, beginangle, endangle, radius):
    """Return the point ``radius / 1.5`` away from *center* in the mid-angle direction.

    The direction is that of the vector (0, -radius / 1.5) rotated by the
    mean of *beginangle* and *endangle* (in degrees). Rotating a vector with a
    zero x component reduces to (d * sin(theta), -d * cos(theta)).
    """
    distance = radius / 1.5
    theta = math.radians((beginangle + endangle) / 2.0)
    return center + oedepict.OE2DPoint(distance * math.sin(theta),
                                       -distance * math.cos(theta))


def set_properties(mol, properties):
    """Compute properties of *mol* with OEMolProp and store them on the displays.

    The filter defined by get_filter_rules() is run in table mode. The header
    row and the single data row written for *mol* are split on tabs, stripped,
    and zipped into a column-name -> value mapping. Every display whose filter
    name matches a column receives that value via SetValue. Displays without a
    matching column are left untouched.

    The OEThrow level is set to OEErrorLevel_Warning while filtering and is
    always restored afterwards, even if filtering raises.

    Args:
        mol: molecule to evaluate.
        properties: iterable of OEPropertyDisplay; filter names may be bytes
            or str.
    """
    ifs = oechem.oeisstream(get_filter_rules())
    molfilter = oemolprop.OEFilter(ifs)

    level = oechem.OEThrow.GetLevel()
    oechem.OEThrow.SetLevel(oechem.OEErrorLevel_Warning)
    try:
        ostr = oechem.oeosstream()
        pwnd = False
        molfilter.SetTable(ostr, pwnd)

        headerrow = ostr.str().decode("UTF-8")
        ostr.clear()  # remove the header row from the stream

        molfilter(mol)

        datarow = ostr.str().decode("UTF-8")
        ostr.clear()  # remove this row from the stream
    finally:
        oechem.OEThrow.SetLevel(level)

    # Strip every cell: the last cell of a row presumably carries the line
    # terminator, which would otherwise break key lookups and value parsing.
    headers = [cell.strip() for cell in headerrow.split('\t')]
    fields = [cell.strip() for cell in datarow.split('\t')]
    filterdict = dict(zip(headers, fields))

    for prop in properties:
        name = prop.GetFilterName()
        if isinstance(name, bytes):
            name = name.decode("UTF-8")
        if name in filterdict:
            prop.SetValue(filterdict[name])


def get_filter_rules():
    """Return the OEMolProp filter definition used to tabulate properties.

    The rules set limits on molecular weight, XLogP, 2D polar surface area,
    solubility class, Lipinski donors/acceptors and Lipinski violations.
    set_properties() feeds this text to oemolprop.OEFilter.
    """
    return """
# This file defines the rules for filtering multi-structure files based on
# properties and substructure patterns.

MIN_MOLWT      130       "Minimum molecular weight"
MAX_MOLWT      781       "Maximum molecular weight"

MIN_XLOGP      -3.0      "Minimum XLogP"
MAX_XLOGP       6.85     "Maximum XLogP"

PSA_USE_SandP   false    "Count S and P as polar atoms"
MIN_2D_PSA      0.0      "Minimum 2-Dimensional (SMILES) Polar Surface Area"
MAX_2D_PSA      205.0    "Maximum 2-Dimensional (SMILES) Polar Surface Area"

# choices are insoluble<poorly<moderately<soluble<very<highly
MIN_SOLUBILITY insoluble "Minimum solubility"

MIN_LIPINSKI_DONORS  0      "Minimum number of hydrogens on O & N atoms"
MAX_LIPINSKI_DONORS  6      "Maximum number of hydrogens on O & N atoms"

MIN_LIPINSKI_ACCEPTORS  1   "Minimum number of oxygen & nitrogen atoms"
MAX_LIPINSKI_ACCEPTORS  14  "Maximum number of oxygen & nitrogen atoms"

MAX_LIPINSKI   3         "Maximum number of Lipinski violations"
"""


def get_property_displays():
    """Return the property displays, one per pie slice, in drawing order.

    Solubility classes reported by the filter are mapped to ranks 1
    ('insoluble') through 6 ('highly'), so they can be placed on a numeric
    color gradient.
    """
    solubility_classes = ('insoluble', 'poorly', 'moderately', 'soluble', 'very', 'highly')
    soldict = {name: rank for rank, name in enumerate(solubility_classes, start=1)}

    return [
        OEPropertyDisplay('MW', 'Molecular weight', b'molecular weight',
                          get_mol_weight_color_gradient()),
        OEPropertyDisplay('XLogP', 'XLogP', b'XLogP', get_XLogP_color_gradient()),
        OEPropertyDisplay('TPSA', 'Topological PSA', b'2d PSA', get_TPSA_color_gradient()),
        OEPropertyDisplay('Sol', 'Solubility', b'Solubility',
                          get_solubility_color_gradient(), soldict),
        OEPropertyDisplay('Rof5', 'Rule of five', b'Lipinski violations',
                          get_Lipinski_color_gradient()),
    ]


def get_mol_weight_color_gradient():
    """Return the molecular-weight gradient.

    The gradient is red at 0, yellow at half of 314, green at 314, yellow at
    314 + 1.5 * 128 and red at 314 + 3 * 128.
    """
    mean, sigma = 314.0, 128
    stops = (
        (0.0, oechem.OERed),
        (mean / 2.0, oechem.OEYellow),
        (mean, oechem.OEGreen),
        (mean + 1.5 * sigma, oechem.OEYellow),
        (mean + 3 * sigma, oechem.OERed),
    )
    gradient = oechem.OELinearColorGradient()
    for position, color in stops:
        gradient.AddStop(oechem.OEColorStop(position, color))
    return gradient


def get_XLogP_color_gradient():
    """Return the XLogP gradient.

    The gradient is green at 2.472, yellow at +/-1.5 * 2.013 from it and red
    at +/-3 * 2.013 from it.
    """
    mean, sigma = 2.472, 2.013
    stops = (
        (mean - 3 * sigma, oechem.OERed),
        (mean - 1.5 * sigma, oechem.OEYellow),
        (mean, oechem.OEGreen),
        (mean + 1.5 * sigma, oechem.OEYellow),
        (mean + 3 * sigma, oechem.OERed),
    )
    gradient = oechem.OELinearColorGradient()
    for position, color in stops:
        gradient.AddStop(oechem.OEColorStop(position, color))
    return gradient


def get_TPSA_color_gradient():
    """Return the topological polar surface area gradient.

    The gradient is red at 0, yellow at half of 62.8, green at 62.8, yellow at
    62.8 + 1.5 * 50.5 and red at 62.8 + 3 * 50.5.
    """
    mean, sigma = 62.8, 50.5
    stops = (
        (0.0, oechem.OERed),
        (mean / 2.0, oechem.OEYellow),
        (mean, oechem.OEGreen),
        (mean + 1.5 * sigma, oechem.OEYellow),
        (mean + 3 * sigma, oechem.OERed),
    )
    gradient = oechem.OELinearColorGradient()
    for position, color in stops:
        gradient.AddStop(oechem.OEColorStop(position, color))
    return gradient


def get_solubility_color_gradient():
    """Return the solubility-rank gradient: red at 1, yellow at 3.5, green at 6."""
    stops = ((1.0, oechem.OERed), (3.5, oechem.OEYellow), (6.0, oechem.OEGreen))
    gradient = oechem.OELinearColorGradient()
    for position, color in stops:
        gradient.AddStop(oechem.OEColorStop(position, color))
    return gradient


def get_Lipinski_color_gradient():
    """Return the Lipinski-violation gradient: green at 0, yellow at 2, red at 3."""
    stops = ((0.0, oechem.OEGreen), (2.0, oechem.OEYellow), (3.0, oechem.OERed))
    gradient = oechem.OELinearColorGradient()
    for position, color in stops:
        gradient.AddStop(oechem.OEColorStop(position, color))
    return gradient


InterfaceData = """
!CATEGORY "input/output options"

    !PARAMETER -in
      !ALIAS -i
      !TYPE string
      !REQUIRED true
      !KEYLESS 1
      !VISIBILITY simple
      !BRIEF Input filename
    !END

    !PARAMETER -out
      !ALIAS -o
      !TYPE string
      !REQUIRED true
      !KEYLESS 2
      !VISIBILITY simple
      !BRIEF Output filename
    !END

!END

!CATEGORY "display options"

    !PARAMETER -colorgradient
      !ALIAS -colorg
      !TYPE bool
      !DEFAULT false
      !VISIBILITY simple
      !BRIEF Display property value on color gradient
    !END

!END
"""

if __name__ == "__main__":
    sys.exit(main(sys.argv))
