#!/usr/bin/env python
# (C) 2022 Cadence Design Systems, Inc. (Cadence) 
# All rights reserved.
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of Cadence products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. Cadence claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable Cadence offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
# liable for any damages or liability in connection with the Sample Code
# or its use.


'''
This script prepares an MD or docking-ready design unit with receptor from an input PDB or mmCIF file.

Input parameters:
    Required:
        -in: The input PDB or mmCIF file
        -site_residue: Defines a binding site residue for pocket detection (Required for apo structures only)

    Optional:
        -map: Input mtz file containing the electron density map
        -loop_db: Input loop_db file for loop modeling
        -generate_tautomers: generate and use tautomers in the hydrogen network optimization
        -prefix: String to prepend to all output DU files
        -metadata: Input metadata JSON file
        -allow_filter_error: Run spruce prep even when structure fails spruce filter
        -verbose: boolean flag to trigger verbose logging

Usage examples:
python spruce_prep.py -in 3fly.cif
python spruce_prep.py -in 3fly.pdb -metadata 3fly_metadata.json -generate_tautomers false
python spruce_prep.py -in 3fly.pdb -map 3fly.mtz -verbose true -prefix 3FLY -loop_db my_loop_db.loop_db
python spruce_prep.py -in 3p2q.pdb -site_residue 'HIS: 76: :A'
'''


import sys
from openeye import oechem
from openeye import oegrid
from openeye import oespruce
from openeye import oedocking


# Command-line parameter definitions handed to oechem.OEInterface in main().
# Every parameter name queried in main() must match a !PARAMETER entry here.
InterfaceData = '''
!PARAMETER -in
    !TYPE string
    !REQUIRED true
    !BRIEF Input PDB/CIF file name
!END

!PARAMETER -site_residue
    !TYPE string
    !REQUIRED false
    !BRIEF Site residue specification to identify binding site (ex: 'HIS:42: :A')
!END

!PARAMETER -prefix
    !TYPE string
    !REQUIRED false
    !BRIEF prefix to prepend to all output DU file names
!END

!PARAMETER -map
    !TYPE string
    !REQUIRED false
    !LEGAL_VALUE *.mtz
    !BRIEF Input electron density file 
!END

!PARAMETER -loop_db
    !TYPE string
    !REQUIRED false
    !LEGAL_VALUE *.loop_db
    !BRIEF Input database for loop modeling
!END

!PARAMETER -generate_tautomers
    !TYPE bool
    !REQUIRED false
    !DEFAULT true
    !BRIEF Option to generate and use tautomers in the hydrogen network optimization (optional)
!END

!PARAMETER -metadata
    !TYPE string
    !REQUIRED false
    !LEGAL_VALUE *.json
    !BRIEF Input structure metadata json file
!END

!PARAMETER -allow_filter_error
    !TYPE bool
    !REQUIRED false
    !DEFAULT false
    !BRIEF Option to allow running spruce prep even when structure fails spruce filter.
!END

!PARAMETER -verbose
    !TYPE bool
    !REQUIRED false
    !DEFAULT false
    !BRIEF Boolean flag to trigger verbose logging
!END
'''


def main(argv=sys.argv):
    '''
    Build design units with receptors from a PDB or mmCIF file and write
    each one to an .oedu file.

    Steps, in order:
      1. Parse the command line using the InterfaceData definitions.
      2. Read the first molecule from the input file.
      3. Optionally read an MTZ electron density map (-map) and a structure
         metadata JSON file (-metadata).
      4. Run the Spruce filter; stop unless -allow_filter_error is set.
      5. Make the design units, optionally guided by -site_residue.
      6. Validate each design unit, add a receptor, and write it to
         "[<prefix>_]<basename>.oedu" in the current working directory.

    Parameters:
        argv: command-line argument list handed to oechem.OEInterface.

    Errors that prevent processing (unreadable input, unsupported format,
    failed filter) are reported with oechem.OEThrow.Fatal. Problems with an
    individual design unit are reported as warnings and processing continues.
    '''
    itf = oechem.OEInterface(InterfaceData, argv)

    # read input parameters; optional string parameters are None when absent
    ifile = itf.GetString('-in')
    mapfile = itf.GetString('-map') if itf.HasString('-map') else None
    # the parameter is declared as '-loop_db' in InterfaceData
    loopfile = itf.GetString('-loop_db') if itf.HasString('-loop_db') else None
    site_residue = itf.GetString('-site_residue') if itf.HasString('-site_residue') else None
    prefix = itf.GetString('-prefix') if itf.HasString('-prefix') else None

    if itf.GetBool('-verbose'):
        oechem.OEThrow.SetLevel(oechem.OEErrorLevel_Verbose)

    metadata_json = None
    if itf.HasString('-metadata'):
        metadata_json_name = itf.GetString('-metadata')
        try:
            with open(metadata_json_name, "r") as f:
                metadata_json = f.read()
        except OSError:
            # report unreadable input the same way as the other input files
            oechem.OEThrow.Fatal(f'Unable to open {metadata_json_name} for reading')

    allow_filter_error = itf.GetBool('-allow_filter_error')
    generate_tautomers = itf.GetBool('-generate_tautomers')

    # read PDB or CIF input file
    ifs = oechem.oemolistream()
    if not ifs.open(ifile):
        oechem.OEThrow.Fatal(f'Unable to open {ifile} for reading')

    if ifs.GetFormat() not in (oechem.OEFormat_PDB, oechem.OEFormat_CIF):
        oechem.OEThrow.Fatal('Input file must be .pdb or .cif')

    # this flavor is registered for the PDB format only
    ifs.SetFlavor(oechem.OEFormat_PDB,
                  oechem.OEIFlavor_PDB_Default |
                  oechem.OEIFlavor_PDB_DATA |
                  oechem.OEIFlavor_PDB_ALTLOC)

    mol = oechem.OEGraphMol()
    if not oechem.OEReadMolecule(ifs, mol):
        oechem.OEThrow.Fatal(f'Unable to read molecule from {ifile}')

    # read mtz file if included; the grid is left as constructed otherwise
    ed = oegrid.OESkewGrid()
    if mapfile is not None:
        if not oegrid.OEReadMTZ(mapfile, ed, oegrid.OEMTZMapType_Fwt):
            oechem.OEThrow.Fatal(f'Unable to read electron density file {mapfile}')

    makedu_opts = oespruce.OEMakeDesignUnitOptions()

    split_opts = makedu_opts.GetSplitOptions()
    split_opts.SetAlternateLocationHandling(oespruce.OEAlternateLocationOption_Combinatorial)
    split_opts.SetMinLigAtoms(8)
    split_opts.SetMaxLigAtoms(200)
    split_opts.SetMaxLigResidues(20)

    makedu_opts.GetPrepOptions().GetEnumerateSitesOptions().SetEnumerateCofactorSites(False)
    makedu_opts.GetPrepOptions().GetEnumerateSitesOptions().SetCollapseNonSiteAlts(False)

    # read metadata file if provided
    metadata = oespruce.OEStructureMetadata()
    if metadata_json is not None:
        oespruce.OEStructureMetadataFromJson(metadata, metadata_json)

    # set loop database if included
    if loopfile is not None:
        makedu_opts.GetPrepOptions().GetBuildOptions().GetLoopBuilderOptions().SetLoopDBFilename(loopfile)

    # set tautomer generation flag
    makedu_opts.GetPrepOptions().GetProtonateOptions().SetGenerateTautomers(generate_tautomers)

    # run Spruce filter
    filter_opts = oespruce.OESpruceFilterOptions()
    spruce_filter = oespruce.OESpruceFilter(filter_opts, makedu_opts)
    ret_filter = spruce_filter.StandardizeAndFilter(mol, ed, metadata)

    if ret_filter != oespruce.OESpruceFilterIssueCodes_Success:
        oechem.OEThrow.Warning('This structure fails spruce filter due to: ')
        oechem.OEThrow.Warning(spruce_filter.GetMessages())
        if not allow_filter_error:
            oechem.OEThrow.Fatal('This structure fails spruce filter')

    # make the DUs: the density grid is passed only when a map was read, and
    # the site residue only when one was given (otherwise a bound ligand is
    # assumed to be present in the structure)
    makedu_args = [mol]
    if mapfile is not None:
        makedu_args.append(ed)
    makedu_args.extend([metadata, makedu_opts])
    if site_residue is not None:
        makedu_args.append(site_residue)
    design_units = oespruce.OEMakeDesignUnits(*makedu_args)

    # validate the DUs
    validator = oespruce.OEValidateDesignUnit()
    du_count = 0
    for design_unit in design_units:
        du_count += 1
        title = design_unit.GetTitle()

        ret_validator = validator.Validate(design_unit, metadata)
        if ret_validator != oespruce.OEDesignUnitIssueCodes_Success:
            oechem.OEThrow.Warning(f'Design unit {title} did not pass the DU validator.')
            oechem.OEThrow.Warning(validator.GetMessages())

        # make the receptor; the DU is still written if this fails
        ropts = oedocking.OEMakeReceptorOptions()
        if not oedocking.OEMakeReceptor(design_unit, ropts):
            oechem.OEThrow.Warning(f'Unable to generate receptor for design unit {title}')

        # write the DU
        print(title)
        # turn the title into a file name by replacing brackets, spaces and
        # slashes.
        # NOTE(review): [:-1] drops the final character, presumably the '_'
        # left by a closing ')' at the end of the title -- confirm that DU
        # titles always end with ')'.
        basename = title.replace('(', '_').replace(')', '_').replace(' > ', 'DU_').replace(' ', '_').replace('/', '-')[:-1]
        if prefix is not None:
            ofile = f'{prefix}_{basename}.oedu'
        else:
            ofile = f'{basename}.oedu'

        if not oechem.OEWriteDesignUnit(ofile, design_unit):
            oechem.OEThrow.Warning(f'Unable to write design unit {title}')

    if du_count == 0:
        # make an empty result visible instead of exiting silently
        oechem.OEThrow.Warning(f'No design units were generated from {ifile}')


# Script entry point: run main() with the process arguments and hand its
# return value to sys.exit() as the exit status.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
