/* 
(C) 2022 Cadence Design Systems, Inc. (Cadence) 
All rights reserved.
TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
provided to current licensees or subscribers of Cadence products or
SaaS offerings (each a "Customer").
Customer is hereby permitted to use, copy, and modify the Sample Code,
subject to these terms. Cadence claims no rights to Customer's
modifications. Modification of Sample Code is at Customer's sole and
exclusive risk. Sample Code may require Customer to have a then
current license or subscription to the applicable Cadence offering.
THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED.  OPENEYE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
liable for any damages or liability in connection with the Sample Code
or its use.
*/
/****************************************************************************
* Find the minimum path length between atoms matching two SMARTS patterns,
* or the path length between two named atoms
****************************************************************************/
#include <iostream>
#include <limits>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <openeye.h>
#include <oesystem.h>
#include <oechem.h>
#include "minpath.itf"

using namespace std;
using namespace OESystem;
using namespace OEChem;

/****************************************************************************
* For every molecule read from ifs:
*   - assign atom names with OETriposAtomNames and look up the first atoms
*     named atm1 and atm2;
*   - report the path length between them (plus the molecule SMILES) on
*     stdout when -verbose is set or no -o output was requested;
*   - extract the atoms on the shortest path as a subset molecule (implicit
*     H counts adjusted) and write it to ofs when -o is set, or print its
*     SMILES when only -verbose is set.
* Molecules in which either named atom is missing are skipped with a warning.
****************************************************************************/
static void AtomPathLength(oemolistream& ifs, oemolostream& ofs, OEInterface& itf,
                           const string& atm1, const string& atm2)
{
  // Options do not change during the run; query them once.
  const bool verbose = itf.Get<bool>("-verbose");
  const bool hasOutput = itf.Has<std::string>("-o");

  OEGraphMol mol;
  while (OEReadMolecule(ifs, mol))
  {
    // Names are (re)assigned here, so -atom1/-atom2 refer to the names
    // generated by OETriposAtomNames (presumably replacing any input names).
    OETriposAtomNames(mol);

    OEAtomBase* a1 = nullptr;
    OEAtomBase* a2 = nullptr;
    for (OEIter<OEAtomBase> atom = mol.GetAtoms(); atom; ++atom)
    {
      OEAtomBase* aptr = atom;
      // Keep the first atom carrying each name; never overwrite a hit.
      if (!a1 && aptr->GetName() == atm1)
        a1 = aptr;
      if (!a2 && aptr->GetName() == atm2)
        a2 = aptr;
      if (a1 && a2)
        break;
    }

    if (!(a1 && a2))
    {
      // Include the title so the offending record can be identified,
      // consistent with the SMARTS mode's warning.
      OEThrow.Warning("Failed to find atoms %s and %s in %s, skipping",
                      atm1.c_str(), atm2.c_str(), mol.GetTitle());
      continue;
    }

    const auto pathlen = OEGetPathLength(a1, a2);
    if (verbose || !hasOutput)
    {
      string smiles;
      OECreateIsoSmiString(smiles, mol);
      cout << "Path length: " << pathlen << " in " << smiles << endl;
    }

    OEIter<OEAtomBase> spath = OEShortestPath(a1, a2);
    OEGraphMol spathmol;
    const bool adjustHCount = true;
    OESubsetMol(spathmol, mol, OEIsAtomMember(spath), adjustHCount);

    if (hasOutput)
      OEWriteMolecule(ofs, spathmol);
    else if (verbose)
    {
      // The fragment SMILES is only needed when it is printed.
      string spathsmiles;
      OECreateIsoSmiString(spathsmiles, spathmol);
      cout << spathsmiles << endl;
    }
  }
}

/****************************************************************************
* For every molecule read from ifs:
*   - find the shortest path length between any target atom of a unique
*     match of ss1 and any target atom of a unique match of ss2;
*   - report that minimum (plus the molecule SMILES) on stdout when
*     -verbose is set or no -o output was requested;
*   - for every atom pair achieving the minimum, extract the atoms on the
*     shortest path as a subset molecule (implicit H counts adjusted) and
*     write each distinct fragment, keyed by SMILES, to ofs when -o is set,
*     or print its SMILES when only -verbose is set.
* Molecules lacking a match for either pattern are skipped with a warning.
****************************************************************************/
static void SmartsPathLength(oemolistream& ifs, oemolostream& ofs, OEInterface& itf,
                             const OESubSearch& ss1, const OESubSearch& ss2)
{
  // Options do not change during the run; query them once.
  const bool verbose = itf.Get<bool>("-verbose");
  const bool hasOutput = itf.Has<std::string>("-o");

  OEGraphMol mol;
  while (OEReadMolecule(ifs, mol))
  {
    OEPrepareSearch(mol, ss1);
    OEPrepareSearch(mol, ss2);
    if (!(ss1.SingleMatch(mol) && ss2.SingleMatch(mol)))
    {
      OEThrow.Warning("Unable to find SMARTS matches in %s, skipping", mol.GetTitle());
      continue;
    }

    // Collect, in iteration order, every (atom1, atom2) pair over all unique
    // match combinations whose path length equals the overall minimum. One
    // running minimum suffices: the list is reset whenever a strictly
    // shorter path appears, and ties are appended.
    vector<pair<OEAtomBase*, OEAtomBase*> > minpairs;
    unsigned int minlen = std::numeric_limits<unsigned int>::max();
    const bool unique = true;
    for (OEIter<const OEMatchBase> match1 = ss1.Match(mol, unique); match1; ++match1)
    {
      for (OEIter<const OEMatchBase> match2 = ss2.Match(mol, unique); match2; ++match2)
      {
        for (OEIter<OEAtomBase> ai1 = match1->GetTargetAtoms(); ai1; ++ai1)
        {
          OEAtomBase* atom1 = ai1;
          for (OEIter<OEAtomBase> ai2 = match2->GetTargetAtoms(); ai2; ++ai2)
          {
            OEAtomBase* atom2 = ai2;
            const unsigned int pathlen = OEGetPathLength(atom1, atom2);
            if (pathlen < minlen)
            {
              minlen = pathlen;
              minpairs.clear();
            }
            if (pathlen == minlen)
              minpairs.push_back(make_pair(atom1, atom2));
          }
        }
      }
    }

    if (verbose || !hasOutput)
    {
      string smiles;
      OECreateIsoSmiString(smiles, mol);
      cout << "Shortest path length: " << minlen << " in " << smiles << endl;
    }

    // Different atom pairs (e.g. from symmetric matches) frequently yield the
    // same fragment; emit each distinct fragment SMILES only once.
    set<string> seen;
    for (const auto& atompair : minpairs)
    {
      OEIter<OEAtomBase> spath = OEShortestPath(atompair.first, atompair.second);
      OEGraphMol spathmol;
      // Adjust implicit H counts exactly as the atom-name mode does, so both
      // modes emit identically hydrogenated fragments and SMILES.
      const bool adjustHCount = true;
      OESubsetMol(spathmol, mol, OEIsAtomMember(spath), adjustHCount);
      string spathsmiles;
      OECreateIsoSmiString(spathsmiles, spathmol);
      if (!seen.insert(spathsmiles).second)
        continue;

      if (hasOutput)
        OEWriteMolecule(ofs, spathmol);
      else if (verbose)
        cout << spathsmiles << endl;
    }
  }
}

/****************************************************************************
* Entry point: parse the command line (interface defined in minpath.itf),
* open the input stream and, if -o is given, the output stream, then run
* either SMARTS mode (-smarts1/-smarts2) or atom-name mode (-atom1/-atom2).
* Exactly one of the two option pairs must be fully specified.
****************************************************************************/
int main(int argc, char *argv[])
{
  OEInterface itf(InterfaceData, argc, argv);

  const bool smartsMode = itf.Has<std::string>("-smarts1") && itf.Has<std::string>("-smarts2");
  const bool atomMode = itf.Has<std::string>("-atom1") && itf.Has<std::string>("-atom2");
  // The modes are mutually exclusive: reject both "neither" and "both" with a
  // message that actually describes the requirement.
  if (smartsMode == atomMode)
    OEThrow.Fatal("Exactly one of -smarts1/-smarts2 or -atom1/-atom2 must be set");

  const std::string ifname = itf.Get<std::string>("-i");
  oemolistream ifs;
  if (!ifs.open(ifname))
    OEThrow.Fatal("Unable to open %s for reading", ifname.c_str());

  oemolostream ofs;
  if (itf.Has<std::string>("-o"))
  {
    const std::string ofname = itf.Get<std::string>("-o");
    if (!ofs.open(ofname))
      OEThrow.Fatal("Unable to open %s for writing", ofname.c_str());
  }

  if (smartsMode)
  {
    OESubSearch ss1;
    OESubSearch ss2;
    const std::string smarts1 = itf.Get<std::string>("-smarts1");
    if (!ss1.Init(smarts1.c_str()))
      OEThrow.Fatal("Unable to parse SMARTS1: %s", smarts1.c_str());

    const std::string smarts2 = itf.Get<std::string>("-smarts2");
    if (!ss2.Init(smarts2.c_str()))
      OEThrow.Fatal("Unable to parse SMARTS2: %s", smarts2.c_str());

    SmartsPathLength(ifs, ofs, itf, ss1, ss2);
  }
  else
  {
    const std::string atom1 = itf.Get<std::string>("-atom1");
    const std::string atom2 = itf.Get<std::string>("-atom2");
    AtomPathLength(ifs, ofs, itf, atom1, atom2);
  }

  return 0;
}
