# (C) 2022 Cadence Design Systems, Inc. (Cadence) 
# All rights reserved.
# TERMS FOR USE OF SAMPLE CODE The software below ("Sample Code") is
# provided to current licensees or subscribers of Cadence products or
# SaaS offerings (each a "Customer").
# Customer is hereby permitted to use, copy, and modify the Sample Code,
# subject to these terms. Cadence claims no rights to Customer's
# modifications. Modification of Sample Code is at Customer's sole and
# exclusive risk. Sample Code may require Customer to have a then
# current license or subscription to the applicable Cadence offering.
# THE SAMPLE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED.  CADENCE DISCLAIMS ALL WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. In no event shall Cadence be
# liable for any damages or liability in connection with the Sample Code
# or its use.

from uuid import uuid4
from base64 import b64encode, b64decode
from os.path import isfile
from os import remove
from io import BytesIO
from string import (ascii_uppercase,
                    ascii_lowercase,
                    digits,
                    )
from random import choices, uniform

from openeye.oechem import (OERecordToBytes,
                            OEReadRecords,
                            OESmilesToMol,
                            OEWallTimer,
                            OEMolRecord,
                            oeofstream,
                            OERecord,
                            OEField,
                            Types,
                            OEMol,
                            )
from cubeextras.utils import (RecordShardMetadata,
                              OEDBFormatHandler,
                              tempdir_filename,
                              tempdir_full,
                              get_extension,
                              FormatTag,
                              )
from floe.api import (IntegerParameter,
                      ParallelMixin,
                      ComputeCube,
                      SourceCube,
                      )
from orionplatform.mixins import (RecordPortsMixin,
                                  ShardPortsMixin,
                                  )
from orionplatform.ports import (RecordInputPort,
                                 RecordOutputPort,
                                 ShardInputPort,
                                 ShardOutputPort,
                                 CollectionInputPortV2,
                                 )
from orionclient.exceptions import OrionError
from orionclient import Shard
from drconvert import get_converter


class RandomMolRecordGeneratorCube(SourceCube):
    """
        A source cube that yields a configurable number of
        records, each holding an aspirin molecule together
        with a randomly generated name, a random molecular
        weight, and a random boolean flag.
    """
    title = "Random MolRecord Generator"
    description = "A cube that generates <record_number> MolRecords with random attributes"
    classification = [["Educational", "Example", "Cube Development", "Generator", "Record", "Molecule"]]
    tags = ["Educational", "Example", "Cube Development", "Generator", "Record", "Molecule"]
    record_number = IntegerParameter("record_number",
                                     required=True,
                                     title="Number of records",
                                     description="Number of random records to generate",
                                     default=1000,
                                     min_value=1,
                                     )
    success = RecordOutputPort("success")
    random_name_field = OEField("random_name", Types.String)
    random_weight_field = OEField("random_float", Types.Float)
    random_lead_like_field = OEField("random_bool", Types.Bool)

    # Build one record: an aspirin molecule plus random name/weight/flag values
    def make_random_record(self):
        alphabet = ascii_uppercase + ascii_lowercase + digits
        name = ''.join(choices(alphabet, k=10))
        lead_like = choices([True, False])[0]
        weight = round(uniform(10, 100), 2)

        mol = OEMol()
        OESmilesToMol(mol, "CC(=O)Oc1ccccc1C(O)=O")

        record = OEMolRecord()
        record.set_mol(mol)
        for field, value in ((self.random_name_field, name),
                             (self.random_weight_field, weight),
                             (self.random_lead_like_field, lead_like),
                             ):
            record.set_value(field, value)
        return record

    # SourceCube protocol: yield record_number freshly generated records
    def __iter__(self):
        for _ in range(self.args.record_number):
            yield self.make_random_record()


class LogWriterCube(RecordPortsMixin, ComputeCube):
    """
        This cube prints the "random_name" field for
        incoming records and passes each record through
        to the success port unchanged.
    """
    title = "Log Writer Cube"
    classification = [["Examples"]]
    tags = ["Educational", "Example", "Cube Development", "Record"]
    description = "A cube that prints the random_name field from a record"
    # Toggle for console output; subclasses may override
    verbose = True

    def process(self, record, port):
        # The cube reads the "random_name" string field, matching the field
        # written by RandomMolRecordGeneratorCube
        if self.verbose:
            print(record.get_value(OEField("random_name", Types.String)))
        self.success.emit(record)


class ParallelLogWriterCube(ParallelMixin, LogWriterCube):
    """Parallel version of LogWriterCube."""
    title = "Parallel " + LogWriterCube.title


class _BaseBatchCube(RecordPortsMixin, ComputeCube):
    """
        This cube is an example of a base cube for
        making record batches. In this implementation
        a batch is a record with three fields in it:

        1. The "buffer" field contains a string which
           contains the records in the batch encoded
           as binary strings.
        2. The "count" field contains an integer that
           shows the number of records contained in the
           batch string.
        3. The "size" field contains an integer that
           shows the total size (in bytes) of the records
           in the batch string.

        A batch is emitted after a specific number of
        records have been accumulated (specified by the
        cube parameter records_per_batch) or until the
        size of the batch reaches a specified limit
        (specified by the batch_size_limit parameter).
        Both parameters are declared by subclasses.
    """
    title = "Base Batch Cube"
    description = "A Base Class for batches"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]

    # Fields carried by every batch record
    buffer = OEField("buffer", Types.String)
    count = OEField("count", Types.Int)
    size = OEField("size", Types.Int)

    # NOTE(review): class-level mutable state; fine as long as each cube
    # process instantiates a subclass once, shared otherwise -- confirm
    batch = OERecord()
    batch_string = b""
    batch_count = 0

    verbose = False

    # Encode a binary batch string as base64 text
    @staticmethod
    def encode_batch(batch):
        return b64encode(batch).decode("utf-8")

    # Decode base64 text back into a binary batch string
    @staticmethod
    def decode_batch(encoded_batch):
        return b64decode(encoded_batch.encode("utf-8"))

    # Reset the batch record, binary buffer, and record counter
    def batch_init(self):
        self.batch.set_value(self.buffer, "")
        self.batch.set_value(self.count, 0)
        self.batch.set_value(self.size, 0)
        self.batch_string = b""
        self.batch_count = 0

    # True when the batch should be emitted: the record-count target is set
    # (non-zero) and reached, or the buffer byte size limit is reached
    def batch_ready(self):
        record_num_exceeded = self.args.records_per_batch and self.batch_count >= self.args.records_per_batch
        chunk_size_exceeded = len(self.batch_string) >= self.args.batch_size_limit
        return record_num_exceeded or chunk_size_exceeded

    # Decode a batch record's buffer and deserialize it back into records
    def batch_to_records(self, batch):
        with BytesIO(self.decode_batch(batch.get_value(self.buffer))) as obj:
            return list(OEReadRecords(obj))

    # Serialize one record and append it to the current batch
    def batch_add_record(self, record):
        self.batch_string += OERecordToBytes(record)
        self.batch_count += 1

    # Append all records of an incoming batch, emitting whenever full
    def batch_add_batch(self, batch):
        for record in self.batch_to_records(batch):
            self.batch_add_record(record)
            if self.batch_ready():
                self.batch_emit()

    # Emit the current batch (no-op when empty) and reset the accumulator
    def batch_emit(self):
        if self.batch_count == 0:
            return

        size = len(self.batch_string)
        encoded = self.encode_batch(self.batch_string)
        self.batch.set_value(self.size, size)
        self.batch.set_value(self.buffer, encoded)
        self.batch.set_value(self.count, self.batch_count)

        self.success.emit(self.batch)
        if self.verbose:
            self.log.info("Created Batch (size=%s, records=%s) " % (size,
                                                                    self.batch_count,
                                                                    )
                          )
        self.batch_init()


class _BaseShardCube(ShardPortsMixin, ComputeCube):
    """
        This cube is an example of a base cube for
        reading and writing shards in collections.
    """
    write_attempts = IntegerParameter("write_attempts",
                                      title="Shard Write Attempts",
                                      description="Number of attempts to write a batch to a shard",
                                      min_value=1,
                                      default=10,
                                      required=False,
                                      )
    upload_attempts = IntegerParameter("upload_attempts",
                                       title="Shard Upload Attempts",
                                       description="Number of attempts to upload a shard",
                                       min_value=1,
                                       default=10,
                                       required=False,
                                       )

    # NOTE(review): class-level mutable state; assumes one cube instance
    # per process -- confirm against the floe runtime
    collection = None
    shard_name = None
    shard_file = None
    shard_id = None
    shard_size = 0
    record_num = 0
    shard_meta = RecordShardMetadata.create()

    # Batch record fields, mirroring _BaseBatchCube so shard/batch cubes agree
    buffer = OEField("buffer", Types.String)
    count = OEField("count", Types.Int)
    size = OEField("size", Types.Int)

    batch = OERecord()
    batch_string = b""
    batch_count = 0

    verbose = False

    # Build a human-readable shard description for log messages
    def describe_shard(self, attempts=None, shard_error=None):
        description = "(id=%s records=%s size=%s)" % (self.shard_id,
                                                      self.record_num,
                                                      self.shard_size)
        if attempts is not None:
            description += "\nattempts=%s\n" % attempts
        if shard_error is not None:
            description += "\nerror=%s" % shard_error

        return description

    # Create a shard in the collection; raises RuntimeError on failure
    def create_shard(self):
        try:
            shard = Shard.create(self.collection,
                                 name="shard_%s.%s" % (uuid4(), get_extension(self.shard_name)),
                                 metadata=self.shard_meta.metadata())
            self.shard_id = shard.id
            if self.verbose:
                self.log.info("Created Shard: %s " % self.describe_shard())
            return shard
        except OrionError as error:
            # pass the error via the shard_error keyword (the bare positional
            # would land in the attempts slot of describe_shard)
            self.log.error("Failed to create shard: %s" % self.describe_shard(shard_error=error))
            raise RuntimeError(error)

    # Open a local temp file for a new shard; returns False if already open
    def open_shard(self):
        if self.shard_name is None:
            self.shard_name = tempdir_filename(OEDBFormatHandler().extension())
            self.shard_file = oeofstream(self.shard_name)
            if self.verbose:
                self.log.info("Opened Shard: %s " % self.describe_shard())
            return True
        self.log.info("Shard is already open: %s" % self.describe_shard())
        return False

    # Close the shard's local file; returns False if there was none to close
    def close_shard(self):
        try:
            self.shard_file.close()
            if self.verbose:
                self.log.info("Wrote Shard: %s " % self.describe_shard())
            return True
        except AttributeError as error:
            self.log.info("Failed to close shard output. %s" % self.describe_shard(shard_error=error))
            return False

    # Upload the local shard file to Orion; raises RuntimeError on failure
    def upload_shard(self, shard):
        try:
            shard.upload_file(self.shard_name,
                              attempts=self.args.upload_attempts)
            if self.verbose:
                self.log.info("Uploaded Shard: %s " % self.describe_shard())
        except OrionError as error:
            self.log.info("Failed to upload shard: %s" % self.describe_shard(shard_error=error))
            raise RuntimeError(error)

    # Close, create, upload, and emit the current shard, then reset state
    def emit_shard(self):
        if self.close_shard():
            shard = self.create_shard()
            self.upload_shard(shard)
            self.success.emit(shard)
            if self.verbose:
                self.log.info("Created Shard %s in collection %s %s" % (self.describe_shard(),
                                                                        self.collection.name,
                                                                        self.collection.id),
                              )
        self.shard_name = None
        self.record_num = 0

    # Download a shard to a temp file and yield its records one at a time
    def shard_to_records(self, shard):
        if not FormatTag.has(shard):
            # NOTE(review): the original checked FormatTag.has(shard) twice;
            # the second check presumably meant to verify size metadata --
            # confirm against the cubeextras tag API
            raise AttributeError("Shard has no format metadata")
        if tempdir_full(1024 * 1024):
            raise RuntimeError("Cube's local file system is out of disk space.")

        temp_file = tempdir_filename(FormatTag.read(shard).format)
        process_timer = OEWallTimer()
        download_timer = OEWallTimer()
        shard.download_to_file(temp_file)
        if self.verbose:
            # format eagerly so the message is rendered for any logger type,
            # consistent with the %-formatting used elsewhere in this file
            self.log.info("{:f} seconds to download shard {} to file {}".format(
                download_timer.Elapsed(), shard.id, temp_file))
        try:
            converter = get_converter(temp_file)
        except ValueError as e:
            self.log.error("Failed to read shard {}. {}".format(shard.id, str(e)))
            return

        record_count = 0
        for record in converter:
            yield record
            record_count += 1

        if self.verbose:
            self.log.info("{:f} seconds to process shard {} with {} records".format(
                process_timer.Elapsed(), shard.id, record_count))
        # clean up the downloaded temp file
        if isfile(temp_file):
            remove(temp_file)


class RecordsToBatchesCube(_BaseBatchCube):
    """
        Converts a stream of incoming records into batch
        records.

        A batch is emitted once records_per_batch records
        have accumulated, or once the batch buffer reaches
        batch_size_limit bytes, whichever comes first.
    """
    title = "Records To Batches"
    description = "A cube that converts records to record batches"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]
    verbose = True

    records_per_batch = IntegerParameter("records_per_batch",
                                         title="Records per Batch",
                                         description="Desired number of records per Batch. " +
                                                     "0 = Fill batch up to batch_size_limit.",
                                         min_value=0,
                                         required=True,
                                         )
    batch_size_limit = IntegerParameter("batch_size_limit",
                                        title="Maximum Size of Batch",
                                        description="Desired size of Batches (bytes). ",
                                        min_value=1,
                                        max_value=int(4e9),
                                        default=int(1e9),
                                        required=True,
                                        )

    # Start from a clean, empty batch
    def begin(self):
        self.batch_init()

    # Accumulate each record; flush as soon as the batch fills up
    def process(self, record, port):
        self.batch_add_record(record)
        if not self.batch_ready():
            return
        self.batch_emit()

    # Flush whatever remains, even a partial batch
    def end(self):
        self.batch_emit()


class ResizeBatchesCube(RecordsToBatchesCube):
    """
        Repackages incoming batch records into new batch
        records of a different size.

        A batch is emitted once records_per_batch records
        have accumulated, or once the batch buffer reaches
        batch_size_limit bytes, whichever comes first.
    """
    title = "Resize Batches"
    description = "A cube that changes batch size of incoming batch records and emits resized batch records"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]
    verbose = True

    # Unpack each incoming batch into the accumulator; flush when full
    def process(self, batch, port):
        self.batch_add_batch(batch)
        if not self.batch_ready():
            return
        self.batch_emit()


class BatchesToRecordsCube(_BaseBatchCube):
    """
        This cube converts incoming batch records to
        records that are emitted.
    """
    title = "Batches to Records"
    description = "A cube that reads batches and emits their records"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Batches"]
    verbose = True

    # Read a batch and emit its records
    def process(self, batch, port):
        for record in self.batch_to_records(batch):
            # emit through the success port, consistent with every other
            # cube in this file (was self.emit)
            self.success.emit(record)


class ParallelBatchesToRecordsCube(ParallelMixin, BatchesToRecordsCube):
    """Parallel version of BatchesToRecordsCube."""
    title = "Parallel " + BatchesToRecordsCube.title


class BatchesToShardsCube(_BaseShardCube):
    """
        This cube converts incoming batch records to
        shards and uploads them. Each shard will
        contain the same number of records as the
        batch it was made from contained.
    """
    title = "Batches To Shards"
    description = "A cube that converts batch records to Shards in a Collection"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Batches", "Collection"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Batches", "Collection"]

    init = CollectionInputPortV2(initializer=True)
    intake = RecordInputPort()
    success = ShardOutputPort()

    # Decode a base64-encoded batch string into a binary record batch
    @staticmethod
    def decode_batch(encoded_batch):
        return b64decode(encoded_batch.encode("utf-8"))

    # Write a batch record's decoded buffer to a freshly opened shard file.
    # Returns True on success, False after exhausting write_attempts.
    def write_batch_to_shard(self, batch):

        def fail_to_write(write_attempt, write_error):
            self.log.error("Failed to write shard: %s" % self.describe_shard(write_attempt + 1, write_error))

        self.record_num = batch.get_value(self.count)
        self.shard_size = batch.get_value(self.size)

        if self.open_shard():
            # Decode once, outside the retry loop. The original rebound
            # `batch` to bytes on the first attempt, which broke every
            # retry (bytes has no get_value).
            data = self.decode_batch(batch.get_value(self.buffer))
            for attempt in range(self.args.write_attempts):
                try:
                    self.shard_file.write(data)
                    if self.verbose:
                        self.log.info("Wrote Shard: %s " % self.describe_shard())
                    return True
                except OSError as error:
                    fail_to_write(write_attempt=attempt,
                                  write_error=error)

            fail_to_write(self.args.write_attempts,
                          write_error="")
            # reopen the file so a later batch starts from a clean stream
            self.shard_file = oeofstream(self.shard_name)
        return False

    # Receive exactly one collection on the initializer port
    def begin(self):
        for collection in self.init:
            if self.collection is not None:
                raise RuntimeError("Only one collection must be sent to the initializer port, " +
                                   "but multiple received. Check connections.")
            self.collection = collection

        if self.collection is None:
            raise ValueError("No collection was passed to the initialization port")

    # Write each batch to its own shard and upload it
    def process(self, batch, port):
        if self.write_batch_to_shard(batch):
            self.emit_shard()


class ParallelBatchesToShardsCube(ParallelMixin, BatchesToShardsCube):
    """Parallel version of BatchesToShardsCube."""
    title = "Parallel " + BatchesToShardsCube.title


class ShardsToBatchesCube(_BaseBatchCube, _BaseShardCube):
    """
        Turns each incoming shard into a single batch
        record holding exactly the records that shard
        contained.
    """
    title = "Shards to Batches"
    description = "A cube that reads Shards in a Collection and emits a batch record for each shard."
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Shard", "Collection", "Batch"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Collection"]

    verbose = True
    intake = ShardInputPort("intake")
    success = RecordOutputPort("success")

    # Start from an empty batch
    def begin(self):
        self.batch_init()

    # Pack every record of the shard into one batch, then emit it
    def process(self, shard, port):
        for rec in self.shard_to_records(shard):
            self.batch_add_record(rec)
        self.batch_emit()


class ParallelShardsToBatchesCube(ParallelMixin, ShardsToBatchesCube):
    """Parallel version of ShardsToBatchesCube."""
    title = "Parallel " + ShardsToBatchesCube.title


class ExampleShardsToRecordsCube(_BaseShardCube):
    """
        Emits every record contained in each incoming shard.
    """
    title = "Shards to Records"
    description = "A cube that reads Shards in a Collection and emits their records. " \
                  "This cube has the same basic function as CollectionToRecordsCube"
    classification = [["Educational", "Example", "Cube Development", "I/O", "Record", "Shard", "Collection"]]
    tags = ["Educational", "Example", "Cube Development", "I/O", "Record", "Collection"]
    verbose = False

    success = RecordOutputPort("success")

    # Forward each record in the shard to the success port
    def process(self, shard, port):
        for rec in self.shard_to_records(shard):
            self.success.emit(rec)


class ParallelExampleShardsToRecordsCube(ParallelMixin, ExampleShardsToRecordsCube):
    """Parallel version of ExampleShardsToRecordsCube."""
    title = "Parallel " + ExampleShardsToRecordsCube.title