"""
usage:
    load_all_experiments.py [options] PDBDIR [--density DIR] [--biounit BU_DIR]

arguments:
    PDBDIR      Directory with main structure files
    DIR         Directory with MTZ/MAP files
    BU_DIR      Directory with Spruce biounit design-unit (.oedu) files

options:
    -v                    Verbose output
    -d                    Debug mode, no multiprocessing
    --ncpu N              Number of CPUs to use [default: 0]
"""
import attr
import docopt
import multiprocessing
import json
import os
import re
import sys
from yaspin import yaspin

import spruce
from spruce.utils.progress import OEPyProgress
from spruce.utils.file_utils import get_pdb_files

import mmdsclient as mmdscli
from oeclient_utils.exceptions import BadResponse, ValidationError
from openeye import oespruce

def _find_biounit(pdbcode, directory):
    """Return biounit design-unit files in *directory* whose names match *pdbcode*.

    The pattern is ``<pdbcode>.*__DU__.*biounit.oedu`` (case-insensitive). How it
    is applied to file names (match vs. search, recursion) is decided by
    ``spruce.utils.file_utils.get_files_from_regex`` and is not verified here.

    :param pdbcode: structure code; matched literally, not as a regex.
    :param directory: directory handed to ``get_files_from_regex``.
    :returns: the non-empty collection of matches, or None when nothing matched.
    """
    # The code is interpolated into a regex, so escape it to match literally.
    # The extension dot is escaped as well so only a real ".oedu" suffix matches.
    valid = re.compile(r'{}.*__DU__.*biounit\.oedu'.format(re.escape(pdbcode)), re.IGNORECASE)
    results = spruce.utils.file_utils.get_files_from_regex(directory, valid)
    if results:
        return results
    return None


def get_biounits_for_code(pdbcode, pdbfile, bu_dir=None):
    """Locate biounit design-unit files for a single structure code.

    Directories are tried in this order, and the first one with a match wins:
      1. ``bu_dir`` itself, when it is truthy;
      2. for 4-character codes, ``bu_dir/<code[1:3]>`` in lower case, then upper
         case (presumably mirroring a wwPDB-style "divided" layout; unverified);
      3. the directory containing ``pdbfile``, only when ``bu_dir`` is None;
      4. the current working directory.

    :returns: the matches from ``_find_biounit``, or None if no location matched.
    """
    def _search_dirs():
        # Generator keeps path construction lazy: later candidates are only
        # built if the earlier ones produced no match.
        if bu_dir:
            yield bu_dir
            if len(pdbcode) == 4:
                middle = pdbcode[1:3]
                yield os.path.join(bu_dir, middle.lower())
                yield os.path.join(bu_dir, middle.upper())
        if bu_dir is None:
            # look for the file next to the PDB file
            yield os.path.dirname(pdbfile)
        yield '.'

    for candidate in _search_dirs():
        found = _find_biounit(pdbcode, candidate)
        if found:
            return found
    return None


def get_biounits(directory):
    """Return all Spruce biounit design-unit files found in *directory*.

    The pattern selects names containing ``__DU__`` followed by ``biounit.oedu``
    (case-insensitive). Matching details and whether sub-directories are
    scanned are up to ``spruce.utils.file_utils.get_files_from_regex``.
    """
    # Escape the dot so it only matches the literal ".oedu" extension separator.
    valid = re.compile(r'.*__DU__.*biounit\.oedu', re.IGNORECASE)
    return spruce.utils.file_utils.get_files_from_regex(directory, valid)


def load_one_experiment(expt):
    """Add a single experiment to MMDS unless one with the same code already exists.

    May run inside a multiprocessing worker (see ``main``), so it opens its own
    session from the ``MMDS_PROFILE`` environment variable. Service errors and
    unreadable metadata are reported through the return value, not raised, so
    one bad experiment does not abort the whole pool.

    :param expt: Expt with ``code``, ``structure``, ``meta``, ``density`` and
        ``biounits`` set.
    :returns: ``(code, experiment_id_or_None, message)``.
    """
    session = mmdscli.get_session(profile=os.environ["MMDS_PROFILE"])

    # Make sure the experiment isn't already loaded. Use a distinct loop name so
    # the ``expt`` argument is not rebound by the listing.
    filters = {'code': expt.code}
    try:
        for existing in session.list_resources(mmdscli.Experiment, filters=filters):
            return existing.code, existing.id, "experiment already exists"
    except BadResponse as e:
        return expt.code, None, str(e)

    # JSON text is UTF-8 by spec; the context manager closes the handle promptly.
    try:
        with open(expt.meta, encoding="utf-8") as fh:
            json_data = fh.read()
    except OSError as e:
        return expt.code, None, f"Unable to read meta data from {expt.meta}: {e}"

    meta = oespruce.OEStructureMetadata()
    if not oespruce.OEStructureMetadataFromJson(meta, json_data):
        return expt.code, None, f"Unable to read meta data from {expt.meta}"

    try:
        result = mmdscli.Experiment.add(session, expt.code, expt.structure, meta,
                                        density=expt.density, biounits=expt.biounits)
    except (ValidationError, BadResponse) as e:
        return expt.code, None, str(e)
    return expt.code, result.id, "Experiment {} added".format(result.id)


def dump_logs(filename, out, err, log=None):
    """Write one framed result record to *log*.

    :param filename: label printed after ``PDB file:``. ``main`` passes either
        the experiment code or the structure path.
    :param out: first payload, printed under the OUT banner.
    :param err: second payload, printed under the ERR banner.
    :param log: writable text stream. Defaults to the *current* ``sys.stdout``.
    """
    # Resolve stdout at call time. A ``log=sys.stdout`` default is evaluated once
    # at import and would ignore any later redirection of sys.stdout.
    if log is None:
        log = sys.stdout
    log.write('--------------------------------------------------------------------------\n')
    log.write("PDB file: {}\n".format(filename))
    log.write('-------------------------- OUT -------------------------------------------\n')
    log.write('{}\n'.format(out))
    log.write('-------------------------- ERR -------------------------------------------\n')
    log.write('{}\n'.format(err))
    log.write('--------------------------------------------------------------------------\n')


def get_current_experiments():
    """Return the set of experiment codes already stored in MMDS.

    Lists every Experiment visible to the ``MMDS_PROFILE`` session (requesting
    only the ``id`` and ``code`` fields) while a spinner shows a running count.
    If listing fails with BadResponse, the error is shown, the spinner is marked
    as failed, and the codes gathered so far are still returned.
    """
    msg = "Gathering existing experiments"
    with yaspin(text=msg, color="cyan") as sp:
        session = mmdscli.get_session(profile=os.environ["MMDS_PROFILE"])
        filters = {'limit': 1000, 'fields': 'id,code'}
        existing = set()
        try:
            experiments = session.list_resources(mmdscli.Experiment, filters=filters)
            for count, expt in enumerate(experiments, start=1):
                existing.add(expt.code)
                # Refresh the spinner only every 1000 records. The old
                # ``if i % 1000:`` test was inverted and updated on all the others.
                if count % 1000 == 0:
                    sp.text = "{} : {}".format(msg, count)
        except BadResponse as e:
            # sp.write prints above the spinner instead of garbling its line.
            sp.write(str(e))
            sp.fail("✘")
        else:
            sp.text = "{} : {}".format(msg, len(existing))
            sp.ok("✔")
        return existing


@attr.s
class Expt:
    """Input files for one experiment, as grouped by ``gather_experiments``."""

    # structure code; gather_experiments stores it upper-cased and keys on it
    code = attr.ib(default=None, type=str)
    # path to the main structure file
    structure = attr.ib(default=None, type=str)
    # path to the metadata JSON read by OEStructureMetadataFromJson
    meta = attr.ib(default=None, type=str)
    # path to the MTZ/MAP density file, or None when none was found
    density = attr.ib(default=None, type=str)
    # biounit DU file paths. ``factory=list`` gives each instance its own list
    # (``type=[]`` was not a type, and a None default breaks ``len()``/``append``).
    biounits = attr.ib(factory=list, type=list)


def gather_experiments(pdb_dir, density_dir=None, spruce_dir=None):
    """Scan the input directories and group files into Expt records keyed by code.

    A structure file becomes an experiment only when ``<structure>.json``
    metadata exists beside it. Density (MTZ/MAP) files are then attached by the
    upper-cased text before the first ``.``. Biounit files are attached by the
    upper-cased text before the first ``_``. Unmatched files are reported and
    ignored. ``density_dir`` and ``spruce_dir`` fall back to ``pdb_dir``.

    :returns: dict mapping upper-cased code -> Expt.
    """
    density_dir = pdb_dir if density_dir is None else density_dir
    spruce_dir = pdb_dir if spruce_dir is None else spruce_dir

    experiments = {}

    structure_files = get_pdb_files(pdb_dir)
    print('Found {} input structure files'.format(len(structure_files)))
    for path in structure_files:
        # e.g. "pdb1abc.ent" -> "1ABC": keep text before the first dot, drop "pdb".
        stem = os.path.basename(path).split('.')[0]
        pdb_code = stem.replace('pdb', '').upper()
        meta_path = path + ".json"
        if not os.path.isfile(meta_path):
            print("No meta data for experiment: {}".format(pdb_code))
            continue
        experiments[pdb_code] = Expt(code=pdb_code, structure=path, meta=meta_path, biounits=[])

    density_files = spruce.utils.file_utils.get_mtz_files(density_dir)
    print('Found {} input density files'.format(len(density_files)))
    for path in density_files:
        name = os.path.basename(path)
        target = experiments.get(name.split('.')[0].upper())
        if target is None:
            print("Found density {} without a structure file.".format(name))
        else:
            target.density = path

    biounit_files = get_biounits(spruce_dir)
    print('Found {} input biounit files'.format(len(biounit_files)))
    for path in biounit_files:
        name = os.path.basename(path)
        target = experiments.get(name.split('_')[0].upper())
        if target is None:
            print("Found biounit {} without a structure file.".format(name))
        else:
            target.biounits.append(path)

    return experiments


def main():
    """Command-line entry point: load every new on-disk experiment into MMDS.

    Steps:
      1. Parse the docopt usage in the module docstring.
      2. Gather structure, density and biounit files.
      3. Skip experiments that already exist in MMDS or have no biounits.
      4. Load the rest, either in a process pool or serially with ``-d``.

    In pool mode each result is written to ``load_all_experiments.log``. In
    debug mode results go to stdout.
    """
    args = docopt.docopt(__doc__)
    verbose = args['-v']
    debug = args['-d']
    ncpu = int(args['--ncpu'])
    if ncpu == 0:
        ncpu = multiprocessing.cpu_count()
    # Never use more than 8 worker processes.
    ncpu = min(ncpu, 8)

    density_dir, bu_dir = args['DIR'], args['BU_DIR']
    # The options section does not describe --density/--biounit, so docopt treats
    # them as boolean flags and DIR/BU_DIR as ordinary positionals filled in
    # order. With only "--biounit X", X therefore lands in DIR; route it back.
    # NOTE(review): "--biounit X --density Y" (reversed order) still swaps the
    # two directories. Describing both options in the docstring would fix that.
    if args.get('--biounit') and not args.get('--density') and bu_dir is None:
        density_dir, bu_dir = None, density_dir

    experiments = gather_experiments(args['PDBDIR'], density_dir, bu_dir)
    print('Found {} input experiments'.format(len(experiments)))

    existing = get_current_experiments()
    print("Found {} existing experiments.".format(len(existing)))

    missing_experiments = []
    for code, expt in experiments.items():
        if code in existing:
            continue
        if not expt.biounits:
            print('No biounits for {}, skipping'.format(code))
            continue
        missing_experiments.append(expt)

    total = len(missing_experiments)
    print('Found {} new experiments to load'.format(total))

    progress = OEPyProgress(hide=verbose)
    progress.set_task("loading {} experiments".format(total))
    count = 0
    # The context managers close the log file and terminate the pool even if a
    # worker raises.
    with open('load_all_experiments.log', 'w') as logfile, progress:
        if not debug:
            with multiprocessing.Pool(ncpu) as pool:
                for code, out, err in pool.imap_unordered(load_one_experiment, missing_experiments):
                    count += 1
                    progress.set_progress(count / total)
                    dump_logs(code, out, err, log=logfile)
        else:
            for expt in missing_experiments:
                count += 1
                _code, out, err = load_one_experiment(expt)
                progress.set_progress(count / total)
                dump_logs(expt.structure, out, err)


if __name__ == "__main__":
    # Equivalent to sys.exit(main()): propagate main's return value as the exit status.
    raise SystemExit(main())
