#!/usr/bin/env python
# (C) 2022 Cadence Design Systems, Inc. (Cadence) 
# All rights reserved.
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of Cadence products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. Cadence claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable Cadence offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
# liable for any damages or liability in connection with the Sample Code
# or its use.

#############################################################################
# Script to prepare a protein structure into design units (.oedu files)
#############################################################################
import sys
import os
from openeye import oechem
from openeye import oegrid
from openeye import oespruce


def _split_name(path):
    """Return ``(stem, extension)`` for the basename of *path*.

    A trailing ``.gz`` is looked through, so ``"1abc.pdb.gz"`` gives
    ``("1abc", ".pdb")``. The extension is lower-cased; the stem is not.
    """
    name = os.path.basename(path)
    if name.lower().endswith(".gz"):
        name = name[:-3]
    stem, ext = os.path.splitext(name)
    return stem, ext.lower()


def _is_mtz(path):
    """Return True when *path* names an MTZ file, judged by its extension."""
    return _split_name(path)[1] == ".mtz"


def main(argv=sys.argv):
    """Prepare the protein in ``argv[1]`` into design units and write them out.

    Usage: ``<script> <infile> [<mtzfile>] [<loopdbfile>]``

    * ``infile``     -- .pdb or .cif structure to prepare.
    * ``mtzfile``    -- optional MTZ file, read as an FWT map
      (``OEMTZMapType_Fwt``) and passed to the Spruce filter and
      ``OEMakeDesignUnits``.
    * ``loopdbfile`` -- optional loop database handed to the loop builder.

    With exactly one optional argument, it is treated as the MTZ file when
    its extension is ``.mtz`` (case-insensitive, ``.gz`` ignored) and as the
    loop database otherwise.

    Each generated design unit ``i`` is written to ``<stem>_DU_<i>.oedu`` in
    the current working directory. Here ``<stem>`` is the input file name
    without directory and extension.

    Returns 0. Usage and input errors are reported through
    ``OEThrow.Usage`` / ``OEThrow.Fatal``, which are expected to abort the
    script.
    """
    if len(argv) < 2 or len(argv) > 4:
        oechem.OEThrow.Usage("%s <infile> [<mtzfile>] [<loopdbfile>]" % argv[0])

    ifile = argv[1]

    # Classify the optional arguments. With two of them the order is fixed.
    # With one, decide by file extension rather than by a substring match,
    # so that e.g. a loop DB under a directory named "mtz..." is not misread.
    edfile = None
    loopfile = None
    if len(argv) == 4:
        edfile, loopfile = argv[2], argv[3]
    elif len(argv) == 3:
        if _is_mtz(argv[2]):
            edfile = argv[2]
        else:
            loopfile = argv[2]

    ifs = oechem.oemolistream()
    if not ifs.open(ifile):
        oechem.OEThrow.Fatal("Unable to open %s for reading" % ifile)
    # Reject unsupported formats before doing any expensive work.
    if ifs.GetFormat() not in (oechem.OEFormat_PDB, oechem.OEFormat_CIF):
        oechem.OEThrow.Fatal("Only works for .pdb or .cif input files")

    # An empty grid is passed to the filter when no density was supplied.
    ed = oegrid.OESkewGrid()
    if edfile is not None and not oegrid.OEReadMTZ(
        edfile, ed, oegrid.OEMTZMapType_Fwt
    ):
        oechem.OEThrow.Fatal("Unable to read electron density file %s" % edfile)

    # Keep PDB DATA records and alternate locations when reading; presumably
    # Spruce relies on them for standardization/metadata.
    ifs.SetFlavor(
        oechem.OEFormat_PDB,
        oechem.OEIFlavor_PDB_Default
        | oechem.OEIFlavor_PDB_DATA
        | oechem.OEIFlavor_PDB_ALTLOC,
    )

    mol = oechem.OEGraphMol()
    if not oechem.OEReadMolecule(ifs, mol):
        oechem.OEThrow.Fatal("Unable to read molecule from %s" % ifile)
    ifs.close()

    # Set to True to continue (after the warnings) when the filter fails.
    allow_filter_errors = False
    metadata = oespruce.OEStructureMetadata()
    filter_opts = oespruce.OESpruceFilterOptions()
    makedu_opts = oespruce.OEMakeDesignUnitOptions()
    loop_opts = (
        makedu_opts.GetPrepOptions().GetBuildOptions().GetLoopBuilderOptions()
    )
    loop_opts.SetBuildTails(False)
    if loopfile is not None:
        loop_opts.SetLoopDBFilename(loopfile)

    # Named spruce_filter so the builtin ``filter`` is not shadowed.
    spruce_filter = oespruce.OESpruceFilter(filter_opts, makedu_opts)
    ret_filter = spruce_filter.StandardizeAndFilter(mol, ed, metadata)
    if ret_filter != oespruce.OESpruceFilterIssueCodes_Success:
        oechem.OEThrow.Warning("This structure fails spruce filter due to: ")
        oechem.OEThrow.Warning(spruce_filter.GetMessages())
        if not allow_filter_errors:
            oechem.OEThrow.Fatal("This structure fails spruce filter")

    if edfile is not None:
        design_units = oespruce.OEMakeDesignUnits(mol, ed, metadata, makedu_opts)
    else:
        design_units = oespruce.OEMakeDesignUnits(mol, metadata, makedu_opts)

    validator = oespruce.OEValidateDesignUnit()

    # Derive the stem via splitext. Slicing off 4 characters breaks on
    # .mmcif / .gz names. The stem is also never used as a format template,
    # so braces in file names are harmless.
    stem = _split_name(ifile)[0]
    num_written = 0
    for i, design_unit in enumerate(design_units):
        ret_validator = validator.Validate(design_unit, metadata)
        if ret_validator != oespruce.OEDesignUnitIssueCodes_Success:
            oechem.OEThrow.Warning("This generated DU did not pass DU validator.")
            oechem.OEThrow.Warning(validator.GetMessages())
        ofile = "%s_DU_%d.oedu" % (stem, i)
        if oechem.OEWriteDesignUnit(ofile, design_unit):
            num_written += 1
        else:
            oechem.OEThrow.Warning("Unable to write design unit %s" % ofile)

    if num_written == 0:
        oechem.OEThrow.Warning("No design units were written for %s" % ifile)
    return 0


if __name__ == "__main__":
    # Run the script and hand main()'s result to the interpreter as the
    # process exit status.
    exit_status = main(sys.argv)
    sys.exit(exit_status)
