/* 
(C) 2022 Cadence Design Systems, Inc. (Cadence) 
All rights reserved.
TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
provided to current licensees or subscribers of Cadence products or
SaaS offerings (each a "Customer").
Customer is hereby permitted to use, copy, and modify the Sample Code,
subject to these terms. Cadence claims no rights to Customer's
modifications. Modification of Sample Code is at Customer's sole and
exclusive risk. Sample Code may require Customer to have a then
current license or subscription to the applicable Cadence offering.
THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
liable for any damages or liability in connection with the Sample Code
or its use.
*/
//*****************************************************************************
//* Utility to perform a matched pair analysis on a set of structures
//*  and save the index for subsequent analysis
//* ---------------------------------------------------------------------------
//* CreateMMPIndex index_mols output_index
//*
//* index_mols: filename of input molecules to analyze
//* output_index: filename of generated MMP index
//*****************************************************************************
#include <openeye.h>
#include <oesystem.h>
#include <oechem.h>
#include <oemedchem.h>
#include "CreateMMPIndex.itf"

using namespace std;
using namespace OESystem;
using namespace OEChem;
using namespace OEMedChem;


// Restricts the SD data attached to a molecule to a requested set of tags.
//
// The field list handed to the constructor selects one of three modes:
//   - exactly one entry equal to "-ALLSD":   leave all SD data untouched
//   - exactly one entry equal to "-CLEARSD": remove all SD data
//   - anything else: the entries are the SD tags to retain
//
// When asFloating is true, a retained tag must additionally hold a value that
// OEStringToNumber can convert to a float; entries that fail the conversion
// are reported with a warning and removed.
class FilterSDData
{
public:
  // fields:     SD tags to retain, or one of the sentinel entries described above
  // asFloating: when true, only retain tags whose value converts to a float
  FilterSDData(const std::vector<std::string> &fields, bool asFloating=true)
    : _allfields(false), _clearfields(false), _asFloating(asFloating)
  {
    if ((fields.size() == 1) && fields[0] == "-ALLSD")
      _allfields = true;
    else if ((fields.size() == 1) && fields[0] == "-CLEARSD")
      _clearfields = true;
    else
      _fields = fields;
  }

  // Filters the SD data of mol in place.
  //
  // Returns:
  //    0  mol carries no SD data, or all of its SD data was cleared
  //   -1  SD data was left untouched (all-fields mode, or empty field list)
  //   >0  number of SD data entries that were retained
  int FilterMolData(OEMolBase &mol)
  {
    if (!OEHasSDData(mol))
      return 0;

    if (_allfields)
      return -1;

    if (_clearfields)
    {
      OEClearSDData(mol);
      return 0;
    }

    // no explicit field list: nothing to filter against, keep everything
    if (_fields.empty())
      return -1;

    int validdata = 0;
    // tags are collected first and deleted after the loop so the SD data
    // is not modified while it is being iterated
    std::vector<std::string> deletefields;

    for (OEIter<OESDDataPair> dp = OEGetSDDataPairs(mol); dp; ++dp)
    {
      const std::string tag = dp->GetTag();
      if (std::find(_fields.begin(), _fields.end(), tag) == _fields.end())
      {
        deletefields.push_back(tag);
        continue;
      }

      if (_asFloating)
      {
        // look the value up once and reuse it for both the check and the warning
        const std::string value = OEGetSDData(mol, tag);
        float fvalue;
        if (!OEStringToNumber(value, fvalue))
        {
          OEThrow.Warning("Failed to convert %s to numeric value (%s) in %s",
                          tag.c_str(), value.c_str(), mol.GetTitle());
          deletefields.push_back(tag);
          continue;
        }
      }

      // count every retained entry, whether or not numeric validation is
      // enabled; counting only in the numeric branch caused all requested
      // fields to be cleared when asFloating was false
      ++validdata;
    }

    if (!validdata)
      OEClearSDData(mol);
    else
    {
      for (std::vector<std::string>::const_iterator nuke = deletefields.begin(); nuke != deletefields.end(); ++nuke)
        OEDeleteSDData(mol, *nuke);
    }

    return validdata;
  }
protected:
  // default state: explicit (empty) field list with numeric validation enabled
  FilterSDData() : _allfields(false), _clearfields(false), _asFloating(true) {}

  std::vector<std::string> _fields;  // SD tags to retain (unused in the sentinel modes)
  bool _allfields;                   // true: leave all SD data untouched
  bool _clearfields;                 // true: remove all SD data
  bool _asFloating;                  // true: retained values must convert to float
};

int main(int argc, char *argv[])
{
  OEInterface itf(InterfaceData);
  OEConfigureMatchedPairIndexOptions(itf);

  if (!OEParseCommandLine(itf, argc, argv))
    OEThrow.Fatal("Unable to interpret command line!");

  // output index file
  std::string mmpindexfile = itf.Get<string>("-output");
  if (!OEIsMatchedPairAnalyzerFileType(mmpindexfile))
    OEThrow.Fatal("Output file is not a matched pair index type - needs .mmpidx extension: %s",
                  mmpindexfile.c_str());

  // create options class with defaults
  OEMatchedPairAnalyzerOptions opts;
  // set up options from command line
  if (!OESetupMatchedPairIndexOptions(opts, itf))
    OEThrow.Fatal("Error setting matched pair indexing options!");

  // input structures to index
  oemolistream ifsindex;
  if (!ifsindex.open(itf.Get<string>("-input")))
    OEThrow.Fatal("Unable to open %s for reading", itf.Get<string>("-input").c_str());

  // get requested verbosity setting
  bool verbose = itf.Get<bool>("-verbose");
  bool vverbose = itf.Get<bool>("-vverbose");
  if (vverbose)
    verbose = vverbose;

  int maxrec = max(itf.Get<int>("-maxrec"), 0);
  int statusrec = itf.Get<int>("-status");

  if (itf.Get<bool>("-exportcompress"))
  {
    if (!opts.SetOptions(opts.GetOptions() | OEMatchedPairOptions::ExportCompression))
      OEThrow.Warning("Error enabling export compression!");
  }

  bool stripstereo = itf.Get<bool>("-stripstereo");
  bool stripsalts = itf.Get<bool>("-stripsalts");

  bool alldata = itf.Get<bool>("-allSD");
  bool cleardata = itf.Get<bool>("-clearSD");

  std::vector<std::string> keepFields;
  if (itf.Has<std::string>("-keepSD"))
  {
    for (OESystem::OEIter<const std::string> field = itf.GetList<std::string>("-keepSD"); field; ++field)
      keepFields.push_back(field);
    if (verbose && (itf.Get<bool>("-clearSD") || alldata))
      OEThrow.Info("Option -keepSD overriding -allSD, -clearSD");
    alldata = cleardata = false;
  }
  else if (cleardata)
  {
    keepFields.push_back("-CLEARSD");
    if (verbose)
      OEThrow.Info("Forced clearing of all input SD data");
    alldata = false;
  }
  else
  {
    if (verbose && !alldata)
      OEThrow.Info("No SD data handling option specified, -allSD assumed");
    keepFields.push_back("-ALLSD");
    alldata = true;
    cleardata = false;
  }

  if (verbose)
  {
    if (!opts.HasIndexableFragmentHeavyAtomRange())
      OEThrow.Info("Indexing all fragments");
    else
      OEThrow.Info("Limiting fragment cores to %.2f-%.2f%% of input molecules",
                   opts.GetIndexableFragmentRangeMin(), opts.GetIndexableFragmentRangeMax());

    if (statusrec)
      OEThrow.Info("Status output after every %d records", statusrec);

    if (maxrec)
      OEThrow.Info("Indexing a maximum of %d records", maxrec);

    if (itf.Get<bool>("-exportcompress"))
      OEThrow.Info("Removing singleton index nodes from index");

    if (stripstereo)
      OEThrow.Info("Stripping stereo");

    if (stripsalts)
      OEThrow.Info("Stripping salts");

    if (cleardata)
      OEThrow.Info("Clearing all input SD data");
    else if (alldata)
      OEThrow.Info("Retaining all input SD data");
    else if (!keepFields.empty())
    {
      std::string allfields;
      for (std::vector<std::string>::const_iterator fiter = keepFields.begin(); fiter != keepFields.end(); ++fiter)
        allfields += " " + *fiter;
      OEThrow.Info("Retaining floating point SD data fields:%s", allfields.c_str());
    }
  }

  // create indexing engine
  OEMatchedPairAnalyzer mmp(opts);

  // interpret SD fields as floating point data
  FilterSDData validdata = FilterSDData(keepFields, true);

  // add molecules to be indexed
  int record = 0;
  int unindexed = 0;
  OEGraphMol mol;
  while (OEReadMolecule(ifsindex, mol))
  {
    if (!alldata)
    {
      // filter the input molecule SD data based on allowed fields
      validdata.FilterMolData(mol);
    }

    if (stripsalts)
      OEDeleteEverythingExceptTheFirstLargestComponent(mol);

    if (stripstereo)
      OEUncolorMol(mol,
                   (OEUncolorStrategy::RemoveAtomStereo |
                    OEUncolorStrategy::RemoveBondStereo |
                    OEUncolorStrategy::RemoveGroupStereo));

    int status = mmp.AddMol(mol, record);
    if (status != record)
    {
      ++unindexed;
      if (vverbose)
        OEThrow.Info("Input structure not added to index, record=%d status=%s",
                     record, OEMatchedPairIndexStatusName(status));
    }
    ++record;
    if (maxrec && record >= maxrec)
      break;

    if (statusrec && (record % statusrec) == 0)
      OEThrow.Info("Records: %d Indexed: %d Unindexed: %d",
                   record, (record - unindexed), unindexed);
  }

  if (!mmp.NumMols())
    OEThrow.Fatal("No records in index structure file");

  if (!mmp.NumMatchedPairs())
    OEThrow.Fatal("No matched pairs found from indexing, use -fragGe,-fragLe options to extend indexing range");

  if (!OEWriteMatchedPairAnalyzer(mmpindexfile, mmp))
    OEThrow.Fatal("Error serializing MMP index: %s", mmpindexfile.c_str());

  // return some status information
  OEThrow.Info("Records: %d, Indexed: %d, matched pairs: %d",
               record, mmp.NumMols(), mmp.NumMatchedPairs());
  return 0;
}
