#!/usr/bin/env python
# (C) 2022 Cadence Design Systems, Inc. (Cadence) 
# All rights reserved.
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of Cadence products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. Cadence claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable Cadence offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
# liable for any damages or liability in connection with the Sample Code
# or its use.

import os
import sys

from openeye import oechem
from openeye import oefastrocs

# Build "<directory of this script>/../python" and put its resolved form at
# the very front of the module search path (index 0 of sys.path).
# NOTE(review): nothing later in this file imports from that directory.
oepy = os.path.join(os.path.dirname(__file__),
                    os.pardir,
                    "python")
sys.path.insert(0, os.path.realpath(oepy))


def main(argv=None):
    """Run a FastROCS search seeded with a single user-defined start.

    The first molecule of the query file is searched against the shape
    database with ``OEFastROCSOrientation_UserInertialStarts``. The start
    position is taken from the coordinates of one query atom (index 1).
    The query and each scored hit are written to the output file; each hit
    carries ShapeTanimoto / ColorTanimoto / TanimotoCombo SD data.

    :param argv: command line as ``[prog, database, queries, hits.oeb]``.
        ``None`` (the default) behaves like ``[__name__]``, which prints the
        usage message.
    :returns: process exit status. It is 0 on every path that returns,
        including the "no supported GPU" path. Failures go through
        ``oechem.OEThrow.Fatal`` instead.
    """
    if argv is None:
        argv = [__name__]
    if len(argv) < 4:
        oechem.OEThrow.Usage("%s <database> <queries> <hits.oeb>" % argv[0])
        return 0

    # check system: no usable GPU is reported as info, not as an error
    if not oefastrocs.OEFastROCSIsGPUReady():
        oechem.OEThrow.Info("No supported GPU available!")
        return 0

    # read in database
    dbname = argv[1]
    if oechem.OEIsGZip(dbname):
        # Pre-format the message like every other OEThrow call in this
        # script. Passing dbname as an extra argument left it out of the text.
        oechem.OEThrow.Fatal(("%s is an unsupported database file format as it is gzipped.\n"
                              "Preferred formats are .oeb, .sdf or .oez") % dbname)

    print("Opening database file %s ..." % dbname)
    dbase = oefastrocs.OEShapeDatabase()
    moldb = oechem.OEMolDatabase()

    if not moldb.Open(dbname):
        oechem.OEThrow.Fatal("Unable to open '%s'" % dbname)

    dots = oechem.OEThreadedDots(10000, 200, "conformers")
    if not dbase.Open(moldb, dots):
        oechem.OEThrow.Fatal("Unable to initialize OEShapeDatabase on '%s'" % dbname)

    # customize search options
    opts = oefastrocs.OEShapeDatabaseOptions()
    opts.SetInitialOrientation(oefastrocs.OEFastROCSOrientation_UserInertialStarts)

    opts.SetLimit(5)

    # read in query (only the first molecule of the file is used)
    qfname = argv[2]
    qfs = oechem.oemolistream()
    if not qfs.open(qfname):
        oechem.OEThrow.Fatal("Unable to open '%s'" % qfname)

    query = oechem.OEGraphMol()
    if not oechem.OEReadMolecule(qfs, query):
        oechem.OEThrow.Fatal("Unable to read query from '%s'" % qfname)
    qfs.close()

    # The output file receives the query first, then every hit.
    ofs = oechem.oemolostream()
    if not ofs.open(argv[3]):
        oechem.OEThrow.Fatal("Unable to open '%s'" % argv[3])
    oechem.OEWriteMolecule(ofs, query)

    # Use the coordinates of one query atom as the single user start.
    atomIdx = 1
    try:
        xyz = query.GetCoords()[atomIdx]
    except (KeyError, IndexError):
        # e.g. a one-atom query: report it instead of a raw traceback
        oechem.OEThrow.Fatal("Query has no atom with index %d to use as a user start" % atomIdx)

    startsCoords = oechem.OEFloatVector()
    for x in xyz:
        startsCoords.append(x)
    # The start count passed below is len / 3, so the vector must hold
    # whole x, y, z triples.
    if len(startsCoords) % 3 != 0:
        oechem.OEThrow.Fatal("Something went wrong whilst reading in user-starts coordinates")

    # Pass the named vector directly instead of a temporary copy of it.
    opts.SetUserStarts(startsCoords, len(startsCoords) // 3)

    opts.SetMaxOverlays(opts.GetNumInertialStarts() * opts.GetNumUserStarts())

    if opts.GetInitialOrientation() == oefastrocs.OEFastROCSOrientation_UserInertialStarts:
        numStarts = opts.GetNumUserStarts()
        print("This example will use %u starts" % numStarts)

    print("Searching for %s" % qfname)
    for score in dbase.GetSortedScores(query, opts):
        print("Score for mol %u(conf %u) %f shape %f color" % (
               score.GetMolIdx(), score.GetConfIdx(),
               score.GetShapeTanimoto(), score.GetColorTanimoto()))
        dbmol = oechem.OEMol()
        molidx = score.GetMolIdx()
        if not moldb.GetMolecule(dbmol, molidx):
            print("Unable to retrieve molecule '%u' from the database" % molidx)
            continue

        # Pick out the conformer this score refers to, tag it with the
        # scores, and apply the score's transform (presumably the overlay
        # onto the query) before writing it.
        mol = oechem.OEGraphMol(dbmol.GetConf(oechem.OEHasConfIdx(score.GetConfIdx())))
        oechem.OESetSDData(mol, "ShapeTanimoto", "%.4f" % score.GetShapeTanimoto())
        oechem.OESetSDData(mol, "ColorTanimoto", "%.4f" % score.GetColorTanimoto())
        oechem.OESetSDData(mol, "TanimotoCombo", "%.4f" % score.GetTanimotoCombo())
        score.Transform(mol)

        oechem.OEWriteMolecule(ofs, mol)

    # Close explicitly so the hits are flushed before reporting success.
    ofs.close()
    print("Wrote results to %s" % argv[3])

    return 0


if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status
    # (equivalent to sys.exit(...)).
    raise SystemExit(main(sys.argv))
