Cube Unit Tests

It is good practice to write unit tests for all cubes. These tests can often quickly find trivial errors and improve development speed. The following is an example of using the Floe `CubeTestRunner`.

Molecular Weight Cube Test

from os import path
from unittest import TestCase

from openeye import oeshape

from floe.test import CubeTestRunner

# Utility for converting a molecule file to records
from drconvert import MolFileConverter
from floe.api.parameters import BooleanParameter
from floe.api import ComputeCube, StringParameter
from orionplatform.mixins import RecordPortsMixin
from orionplatform.parameters import FloatFieldParameter, PrimaryMolFieldParameter

# Note: oechem must be imported before OpenEye toolkits
from openeye.oechem import OECalculateMolecularWeight


# Example cube to be tested, copied from snowball/oechem/molweight.py
class ExampleMolWeightCube(RecordPortsMixin, ComputeCube):
    """Annotate each record with the molecular weight of its molecule.

    Records with no value in ``in_mol_field`` go to the failure port untouched;
    every other record gets ``mwfield`` set and goes to the success port.
    """

    # title, description, classification, etc omitted for brevity

    # Field the molecule is read from (the cube never writes to it).
    in_mol_field = PrimaryMolFieldParameter("in_mol_field", read_only=True)

    # Field that receives the computed weight.
    mwfield = FloatFieldParameter(
        "mwfield",
        default="Molecule Weight",
        required=True,
        title="Molecular Weight Field",
        description="The tag name of the molecule weight field.",
    )

    # Passed straight through as the second argument of
    # OECalculateMolecularWeight.
    isotope = BooleanParameter(
        "isotope",
        title="Isotope Weight",
        default=False,
        description="This parameter determines whether to consider "
        "isotopic data when calculating molecule weight.",
    )

    def process(self, record, port):
        """Store the molecule's weight on ``record`` and route it by outcome."""
        source_field = self.args.in_mol_field
        if record.has_value(source_field):
            molecule = record.get_value(source_field)
            weight = OECalculateMolecularWeight(molecule, self.args.isotope)
            record.set_value(self.args.mwfield, weight)
            self.success.emit(record)
        else:
            self.failure.emit(record)


class MolecularWeightTestCase(TestCase):
    """Unit tests for ExampleMolWeightCube driven through CubeTestRunner."""

    # Input molecule file, located in test_data/ next to this module.
    mol_file = path.join(path.dirname(__file__), "test_data", "100.ism")

    def test_mw_correct(self):
        """Tests that molecular weight is calculated correctly on each record"""
        cube = ExampleMolWeightCube("cube")
        test_input_records = list(MolFileConverter(self.mol_file))

        with CubeTestRunner(cube) as runner:
            runner.send_inputs(intake=test_input_records)
            # The test expects every input record to reach the success port.
            self.assertEqual(
                runner.outputs["success"].emit_count, len(test_input_records)
            )
            self.assertEqual(runner.outputs["failure"].emit_count, 0)

            in_mol_field = cube.args.in_mol_field
            mwfield = cube.args.mwfield
            # Use the cube's own isotope setting so the expected value is
            # computed exactly the way the cube computed it, even if the
            # parameter's default or configuration changes.
            isotope = cube.args.isotope

            # Check the output records
            for record in runner.outputs["success"]:
                self.assertTrue(
                    record.has_value(mwfield),
                    "Molecular weight field missing from output",
                )
                self.assertTrue(
                    record.has_value(in_mol_field),
                    "Output record missing molecule field",
                )
                # Floats may pass through record storage, so compare with a
                # tolerance. msg must be a keyword: the third positional
                # argument of assertAlmostEqual is `places`.
                self.assertAlmostEqual(
                    record.get_value(mwfield),
                    OECalculateMolecularWeight(
                        record.get_value(in_mol_field), isotope
                    ),
                    msg="Mismatched Molecular weight",
                )


class ShapeFuncParam(StringParameter):
    """String parameter whose runtime value is an OpenEye shape function.

    This example shows how to implement a string parameter that maps to a more
    complex object: the user selects one of the ``shape_choices`` labels and the
    cube receives the matching shape-function instance.
    """

    # Choice label -> shape-function instance. The instances are built once,
    # when the class is defined, and shared by all ShapeFuncParam objects.
    shape_choices = {
        "Exact": oeshape.OEExactShapeFunc(),
        "Grid": oeshape.OEGridShapeFunc(),
        "Analytic": oeshape.OEAnalyticShapeFunc(),
    }

    def __init__(self, name):
        labels = list(self.shape_choices)
        super().__init__(
            name,
            title="Shape type",
            description="Type of function to be used for shape overlap evaluation",
            required=False,
            default="Grid",
            choices=labels,
        )

    def get_runtime_value(self, value):
        """Return the shape-function object registered under label ``value``.

        Note: floe checks ``value`` against the declared choices *before* this
              method is called, so an unknown label raises there, not here.
        """
        selected = self.shape_choices[value]
        return selected


class ExampleOverlayCube(RecordPortsMixin, ComputeCube):
    """Example cube showing how a ShapeFuncParam value is consumed."""

    # title, description, classification, etc omitted for brevity

    shape_type = ShapeFuncParam("shape_type")

    def begin(self):
        """Build the overlay options and keep them on ``self.overlay_opts``.

        Storing the options on the instance is what makes them usable from
        ``process``; a local variable would be dropped when ``begin`` returns.
        """
        opts = oeshape.OEOverlayOptions()
        # self.args.shape_type is already the shape-function object produced by
        # ShapeFuncParam.get_runtime_value, not the choice label string.
        func = oeshape.OEOverlapFunc(self.args.shape_type)
        opts.SetOverlapFunc(func)
        self.overlay_opts = opts

    def process(self, record, port):
        # omitted for brevity
        pass


class ExampleOverlayCubeTestCase(TestCase):
    """Tests for ShapeFuncParam as used by ExampleOverlayCube."""

    def test_shape_type_parameter(self):
        """Tests each choice for the shape_type parameter"""
        cube = ExampleOverlayCube("cube")
        for shape_choice, shape_func in ShapeFuncParam.shape_choices.items():
            # subTest reports which choice failed instead of stopping at the first.
            with self.subTest(shape_choice=shape_choice):
                with CubeTestRunner(cube, parameters={"shape_type": shape_choice}):
                    self.assertEqual(cube.args.shape_type, shape_func)

        # ensure that an invalid choice results in an exception
        with self.assertRaises(ValueError) as context:
            cube.set_parameters(shape_type="triangle")
        # Checked after the with-block: any statement placed after the raising
        # call inside the block would never execute.
        self.assertIn("not in choices", str(context.exception))

    def test_shape_type_promoted(self):
        """Tests the case of promoting shape_type"""
        # test parameter promotion
        cube = ExampleOverlayCube("cube")
        cube.promote_parameter("shape_type", promoted_name="st_promoted")

        # Setting the promoted name must configure the original parameter.
        with CubeTestRunner(cube, parameters={"st_promoted": "Exact"}):
            self.assertEqual(
                cube.args.shape_type,
                ExampleOverlayCube.shape_type.shape_choices["Exact"],
            )