# Source code for mosdef_cassandra.writers.inp_functions

import mbuild
import datetime
import numpy as np
import unyt as u
import mosdef_cassandra
import copy


from unyt import dimensions

from mosdef_cassandra.utils.units import validate_unit, validate_unit_list


def generate_input(
    system, moveset, run_type, run_length, temperature, **kwargs
):
    """Construct an input file section by section (with defaults)

    Default options are provided based upon the mosdef_cassandra.System
    and mosdef_cassandra.MoveSet and are typically reasonable choices;
    these may or may not be good choices for your specific system. Any
    options can be overridden by specifying the relevant choices with
    keyword arguments in **kwargs.

    Parameters
    ----------
    system : mosdef_cassandra.System
        system to be simulated
    moveset : mosdef_cassandra.MoveSet
        move probabilities
    run_type : str
        'equil' or 'prod'
    run_length : float
        simulation length in units of (default=steps), or your choice
        as specified by the 'units' option in **kwargs
    temperature : unyt_array or unyt_quantity
        temperature of the system
    **kwargs : dict
        keyword arguments; the accepted names are those returned by
        ``_get_possible_kwargs()``.

    Returns
    -------
    inp_data : str
        a string with the entire input file

    Raises
    ------
    ValueError
        if an unrecognized keyword argument is passed, if a box-2 option
        (``vdw_cutoff_box2``, ``charge_cutoff_box2``, ``pressure_box2``)
        is given for a single-box system, if ``max_molecules`` has the
        wrong length, or if the pressure / chemical potential information
        required by the ensemble is missing or inconsistent. Errors raised
        by the individual section writers propagate unchanged.
    TypeError
        if ``seeds`` is not a list of two values or ``max_molecules`` is
        not a list
    """

    # Sanity check on kwargs
    valid_args = _get_possible_kwargs()
    for arg in kwargs:
        if arg not in valid_args:
            raise ValueError(
                "Invalid input argument {}. "
                "Allowable options include {}".format(arg, valid_args)
            )

    # Check/convert temperature
    validate_unit(temperature, dimensions.temperature)
    temperature = temperature.to("kelvin")

    # Convert moveset units (the "checking" is handled by the setters)
    moveset = _convert_moveset_units(moveset)

    # Check/convert kwargs units. The return value of _convert_kwarg_units
    # is unused, so it presumably converts the dict in place; the
    # .to_value() calls below rely on that.
    _check_kwarg_units(kwargs)
    _convert_kwarg_units(kwargs)

    # Construct an input file section by section
    inp_data = """
! Generated by mosdef_cassandra version {} on {}
""".format(
        mosdef_cassandra.__version__,
        datetime.datetime.now(),
    )

    # Extract some basic info
    nbr_species = len(system.species_topologies)
    nbr_boxes = len(system.boxes)

    # Run name (defaults to the ensemble name)
    inp_data += get_run_name(kwargs.get("run_name", moveset.ensemble))

    # Verbose log (section is only written when requested)
    verbose_log = kwargs.get("verbose_log", False)
    if verbose_log:
        inp_data += get_verbose_log(verbose_log)

    # Ensemble
    inp_data += get_sim_type(moveset.ensemble)

    # Number of species
    inp_data += get_nbr_species(nbr_species)

    # VDW Style
    # NOTE: Once more than LJ is supported (topology object)
    # this can be inferred from system.species
    vdw_style = kwargs.get("vdw_style", "lj")
    cutoff_style = kwargs.get("cutoff_style", "cut_tail")
    if "vdw_cutoff" in kwargs:
        vdw_cutoff = kwargs["vdw_cutoff"].to_value()
    else:
        vdw_cutoff = 12.0

    if vdw_style == "none":
        cutoff_style = None

    vdw_styles = [vdw_style] * nbr_boxes
    cutoff_styles = [cutoff_style] * nbr_boxes
    vdw_cutoffs = [vdw_cutoff] * nbr_boxes
    # Support for per-box cutoffs
    if "vdw_cutoff_box1" in kwargs:
        vdw_cutoffs[0] = kwargs["vdw_cutoff_box1"].to_value()
    if "vdw_cutoff_box2" in kwargs:
        if nbr_boxes == 2:
            vdw_cutoffs[1] = kwargs["vdw_cutoff_box2"].to_value()
        else:
            raise ValueError(
                "Only one box in System but "
                "cutoff for box 2 specified in kwargs"
            )

    # TODO: Check that cutoff <= half box length
    inp_data += get_vdw_style(vdw_styles, cutoff_styles, vdw_cutoffs)

    # Charge Style
    charge_style = kwargs.get("charge_style", "ewald")
    if "charge_cutoff" in kwargs:
        charge_cutoff = kwargs["charge_cutoff"].to_value()
    else:
        charge_cutoff = 12.0
    ewald_accuracy = kwargs.get("ewald_accuracy", 1.0e-5)
    dsf_damping = kwargs.get("dsf_damping")

    charge_styles = [charge_style] * nbr_boxes
    charge_cutoffs = [charge_cutoff] * nbr_boxes

    # Support for per-box cutoffs
    if "charge_cutoff_box1" in kwargs:
        charge_cutoffs[0] = kwargs["charge_cutoff_box1"].to_value()
    if "charge_cutoff_box2" in kwargs:
        if nbr_boxes == 2:
            charge_cutoffs[1] = kwargs["charge_cutoff_box2"].to_value()
        else:
            raise ValueError(
                "Only one box in System but "
                "cutoff for box 2 specified in kwargs"
            )

    # TODO: Check that cutoff <= half box length
    inp_data += get_charge_style(
        charge_styles,
        charge_cutoffs,
        ewald_accuracy=ewald_accuracy,
        dsf_damping=dsf_damping,
    )

    # NOTE: In future get mixing rule from Topology
    mixing_rule = kwargs.get("mixing_rule", "lb")
    custom_mixing_dict = kwargs.get("custom_mixing_dict")
    if "custom_mixing_dict" in kwargs:
        print(
            "Warning: be very careful with custom mixing rules.\n"
            "Please check your final input file."
        )
    inp_data += get_mixing_rule(mixing_rule, custom_mixing_dict)

    # Seeds (None -> a random seed is drawn by get_seed_info)
    if "seeds" in kwargs:
        seeds = kwargs["seeds"]
        if not isinstance(seeds, list) or len(seeds) != 2:
            raise TypeError(
                'The "seeds" argument should be a ' "list of two integers"
            )
        seed1, seed2 = seeds
    else:
        seed1 = seed2 = None
    inp_data += get_seed_info(seed1, seed2)

    # Minimum cutoff
    if "rcut_min" in kwargs:
        rcut_min = kwargs["rcut_min"].to_value()
    else:
        rcut_min = 1.0
    inp_data += get_minimum_cutoff(rcut_min)

    # Pair Energy
    inp_data += get_pair_energy(kwargs.get("pair_energy", True))

    # Molecule Files
    max_molecules_dict = {
        "species%d.mcf" % (i + 1): 0 for i in range(nbr_species)
    }
    if "max_molecules" in kwargs:
        max_molecules = kwargs["max_molecules"]
        if not isinstance(max_molecules, list):
            raise TypeError(
                "Max molecules should be a list, "
                "with one integer per species"
            )
        if len(max_molecules) != nbr_species:
            raise ValueError(
                "Length of list specified with "
                '"max_molecules" ({}) must be equal to the number '
                "of species in the system ({})".format(
                    len(max_molecules), nbr_species
                )
            )
        for isp, max_mols in enumerate(max_molecules):
            max_molecules_dict["species%d.mcf" % (isp + 1)] = max_mols
    else:
        for isp in range(nbr_species):
            max_mols = 0
            for ibox in range(nbr_boxes):
                max_mols += system.mols_in_boxes[ibox][isp]
                max_mols += system.mols_to_add[ibox][isp]
            # TODO: Document/improve this
            # Memory is cheap (most of the time)
            if moveset.ensemble == "gcmc" and moveset.insertable[isp]:
                max_mols += 2000

            max_molecules_dict["species%d.mcf" % (isp + 1)] = max_mols

    inp_data += get_molecule_files(max_molecules_dict)

    # Box Info
    boxes = []
    for box in system.boxes:
        if isinstance(box, mbuild.Compound):
            box_matrix = copy.deepcopy(box.box.vectors)
        else:
            box_matrix = copy.deepcopy(box.vectors)
        boxes.append(u.unyt_array(box_matrix, "nm"))
    inp_data += get_box_info(
        boxes, moveset._restricted_type, moveset._restricted_value
    )

    temperatures = [temperature.to_value()] * nbr_boxes
    inp_data += get_temperature_info(temperatures)

    if moveset.ensemble in ("npt", "gemc_npt"):
        if "pressure" not in kwargs:
            raise ValueError(
                "Pressure must be specified for ensemble "
                '"npt" or "gemc_npt"'
            )
        pressures = [kwargs["pressure"]] * nbr_boxes
        if "pressure_box1" in kwargs:
            pressures[0] = kwargs["pressure_box1"]
        if "pressure_box2" in kwargs:
            # Mirror the per-box cutoff handling instead of letting a
            # single-box system fail with a bare IndexError.
            if nbr_boxes == 2:
                pressures[1] = kwargs["pressure_box2"]
            else:
                raise ValueError(
                    "Only one box in System but "
                    "pressure for box 2 specified in kwargs"
                )

        inp_data += get_pressure_info(pressures)

    if moveset.ensemble == "gcmc":
        missing_chempot_msg = (
            "Chemical potential information must be "
            "specified for each species if the ensemble is "
            '"gcmc". The chemical potential of non-insertable '
            'species should be "none"'
        )
        if "chemical_potentials" not in kwargs:
            raise ValueError(missing_chempot_msg)
        chemical_potentials = kwargs["chemical_potentials"]
        if len(chemical_potentials) != nbr_species:
            raise ValueError(missing_chempot_msg)

        for isp, chempot in enumerate(chemical_potentials):
            # Only real strings are compared with "none"; a unyt_array
            # must never be compared against a string.
            is_none = isinstance(chempot, str) and chempot == "none"
            if not moveset.insertable[isp] and not is_none:
                raise ValueError(
                    "The chemical potential of non-insertable "
                    'species should be "none"'
                )

        inp_data += get_chemical_potential_info(chemical_potentials)

    # Move probability info (only moves with non-zero probability)
    move_prob_dict = {}
    if moveset.prob_translate > 0.0:
        move_prob_dict["translate"] = [
            moveset.prob_translate,
            *[
                [val.to_value() for val in box]
                for box in moveset.max_translate
            ],
        ]
    if moveset.prob_rotate > 0.0:
        move_prob_dict["rotate"] = [
            moveset.prob_rotate,
            *[[val.to_value() for val in box] for box in moveset.max_rotate],
        ]
    if moveset.prob_angle > 0.0:
        move_prob_dict["angle"] = moveset.prob_angle

    if moveset.prob_dihedral > 0.0:
        move_prob_dict["dihedral"] = [
            moveset.prob_dihedral,
            *[[val.to_value() for val in box] for box in moveset.max_dihedral],
        ]
    if moveset.prob_regrow > 0.0:
        move_prob_dict["regrow"] = [
            moveset.prob_regrow,
            moveset.prob_regrow_species,
        ]
    if moveset.prob_volume > 0.0:
        move_prob_dict["volume"] = [
            moveset.prob_volume,
            [i.to_value() for i in moveset.max_volume],
        ]
    if moveset.prob_insert > 0.0:
        move_prob_dict["insert"] = [moveset.prob_insert, moveset.insertable]
    if moveset.prob_swap > 0.0:
        move_prob_dict["swap"] = [
            moveset.prob_swap,
            moveset.insertable,
            moveset.prob_swap_species,
            moveset.prob_swap_from_box,
        ]
    if moveset._restricted_type and moveset._restricted_value:
        move_prob_dict["restricted_insertion"] = [
            moveset._restricted_type,
            moveset._restricted_value,
        ]

    inp_data += get_move_probability_info(**move_prob_dict)

    # CBMC information
    inp_data += get_cbmc_info(
        moveset.cbmc_n_insert,
        moveset.cbmc_n_dihed,
        [i.to_value() for i in moveset.cbmc_rcut],
    )

    # Start type info: mbuild boxes start from an xyz file (optionally
    # adding molecules); empty boxes are built from scratch.
    start_types = []
    for ibox, box in enumerate(system.boxes):
        new_mols = " ".join([str(x) for x in system.mols_to_add[ibox]])
        if isinstance(box, mbuild.Compound):
            existing_mols = " ".join(
                [str(x) for x in system.mols_in_boxes[ibox]]
            )
            xyz_name = "box{}.in.xyz".format(ibox + 1)
            if sum(system.mols_to_add[ibox]) > 0:
                start_type = "add_to_config {} {} {}".format(
                    existing_mols, xyz_name, new_mols
                )
            else:
                start_type = "read_config {} {}".format(
                    existing_mols, xyz_name
                )
        else:
            start_type = "make_config " + new_mols
        start_types.append(start_type)

    inp_data += get_start_type(start_types)

    # Move statistics/updating
    thermal_stat_freq = kwargs.get("thermal_stat_freq", 1000)

    if moveset.ensemble in ("npt", "gemc", "gemc_npt"):
        vol_stat_freq = kwargs.get("vol_stat_freq", 100)
    else:
        vol_stat_freq = None

    if run_type == "equil":
        run_type = "equilibration"
    elif run_type == "prod":
        run_type = "production"

    inp_data += get_run_type(run_type, thermal_stat_freq, vol_stat_freq)

    # Simulation length section
    # TODO: Change units to sweeps and come up with
    # a smart way of calculating the number of sweeps
    inp_data += get_simulation_length_info(
        kwargs.get("units", "steps"),
        kwargs.get("prop_freq", 500),
        kwargs.get("coord_freq", 5000),
        run_length,
        kwargs.get("steps_per_sweep"),
        kwargs.get("block_avg_freq"),
    )

    # Properties section
    if "properties" in kwargs:
        properties = kwargs["properties"]
    else:
        properties = [
            "energy_total",
            "energy_intra",
            "energy_inter",
            "enthalpy",
            "pressure",
            "volume",
            "nmols",
            "mass_density",
        ]

    inp_data += get_property_info(properties, nbr_boxes)

    # Empty fragment section unless restart. On restart, copy the lines
    # following "# Fragment_Files" in the previous input file (verbatim,
    # newlines included) up to the next "!--" separator.
    fragment_files = None
    if kwargs.get("restart"):
        old_inp = kwargs["restart_name"] + ".inp"
        fragment_files = []
        in_fragment_section = False
        with open(old_inp) as f:
            for line in f:
                if "# Fragment_Files" in line:
                    in_fragment_section = True
                    continue
                if in_fragment_section:
                    if "!--" in line:
                        break
                    fragment_files.append(line)

    inp_data += get_fragment_files(fragment_files)

    inp_data += "\nEND\n"

    return inp_data


#######################################################################
#######################################################################
#############  Functions to write individual sections of  #############
#############     the input file are below this point     #############
#######################################################################
#######################################################################


def get_run_name(name):
    """Return the ``# Run_Name`` section of the input file.

    Spaces in ``name`` become underscores and ``.out`` is appended.

    Raises
    ------
    TypeError
        if ``name`` is not a string
    ValueError
        if ``name`` contains a hyphen
    """
    if not isinstance(name, str):
        raise TypeError("name: {} must be a string".format(name))

    if name.find("-") != -1:
        raise ValueError(
            f"run_name: {name} may only contain '0-9', 'A-Z', 'a-z', '.', or '_' characters."
        )

    underscored = "_".join(name.split(" "))
    return (
        "\n# Run_Name\n"
        + underscored
        + ".out\n"
        + "!------------------------------------------------------------------------------\n"
    )


def get_sim_type(sim_type):
    """Get the Sim_Type section of the input file

    Parameters
    ----------
    sim_type : str
        ensemble name; one of "nvt", "npt", "gcmc", "gemc", "gemc_npt"

    Returns
    -------
    inp_data : str
        the "# Sim_Type" section of the input file

    Raises
    ------
    ValueError
        if ``sim_type`` is not one of the supported ensembles
    """

    sim_types = ["nvt", "npt", "gcmc", "gemc", "gemc_npt"]
    if sim_type not in sim_types:
        # Note the trailing space: the two literals are concatenated.
        raise ValueError(
            "Unsupported sim_type: {}. Supported options "
            "include {}".format(sim_type, sim_types)
        )

    inp_data = """
# Sim_Type
{sim_type}
!------------------------------------------------------------------------------
""".format(
        sim_type=sim_type
    )

    return inp_data


def get_nbr_species(nbr_species):
    """Return the ``# Nbr_Species`` section of the input file.

    Raises
    ------
    TypeError
        if ``nbr_species`` is not an ``int``
    """
    if isinstance(nbr_species, int):
        return (
            "\n# Nbr_Species\n{}\n".format(nbr_species)
            + "!------------------------------------------------------------------------------\n"
        )
    raise TypeError("nbr_species must be an int")


def get_vdw_style(vdw_styles, cut_styles, cutoffs):
    """Get the VDW_Style section of the input file

    Parameters
    ----------
    vdw_styles : list
        list of vdw_style for each box, one entry per box. Valid entries
        are "lj" and "none"
    cut_styles : list
        list of cutoff_style for each box, one entry per box. For "lj"
        the valid options are "cut", "cut_tail", "cut_switch" and
        "cut_shift". For a box with vdw_style == 'none', the cutoff
        style is None
    cutoffs : list
        list with cutoffs for each box, one entry per box. A box using
        "cut_switch" needs a 1-D sequence (list, tuple or numpy array)
        of two values, [inner_cutoff, outer_cutoff]. For a box with
        vdw_style == 'none', the cutoff is None (it is not written)

    Returns
    -------
    inp_data : str
        the "# VDW_Style" section, one line per box

    Raises
    ------
    ValueError
        if a vdw style or cutoff style is unsupported, or a "cut_switch"
        box does not provide exactly two cutoffs
    """

    assert len(vdw_styles) == len(cut_styles)
    assert len(vdw_styles) == len(cutoffs)
    valid_cut_styles = {
        "lj": ["cut", "cut_tail", "cut_switch", "cut_shift"],
        "none": [None],
    }
    valid_vdw_styles = list(valid_cut_styles)
    for vdw_style in vdw_styles:
        if vdw_style not in valid_vdw_styles:
            # Report the valid options, not the user's own input list
            raise ValueError(
                "Unsupported vdw_style: {}. Supported options "
                "include {}".format(vdw_style, valid_vdw_styles)
            )
    for cut_style, vdw_style in zip(cut_styles, vdw_styles):
        if cut_style not in valid_cut_styles[vdw_style]:
            raise ValueError(
                "Unsupported cutoff style: {}. Supported "
                "options for the selected vdw_style ({}) include "
                "{}".format(cut_style, vdw_style, valid_cut_styles[vdw_style])
            )

    for cut_style, cutoff in zip(cut_styles, cutoffs):
        if cut_style == "cut_switch":
            # np.ndim guards against len() on scalars / 0-d arrays, which
            # would otherwise raise TypeError instead of this message.
            if np.ndim(cutoff) != 1 or len(cutoff) != 2:
                raise ValueError(
                    'Style "cut_switch" requires an inner '
                    "and outer cutoff. Use the "
                    "cutoffs=[inner_cut,outer_cut] "
                    "kwargs option."
                )

    inp_data = """
# VDW_Style"""

    for vdw_style, cut_style, cutoff in zip(vdw_styles, cut_styles, cutoffs):
        if vdw_style == "none":
            inp_data += "\n{}".format(vdw_style)
        elif cut_style == "cut_switch":
            inp_data += "\n{} {} {} {}".format(
                vdw_style, cut_style, cutoff[0], cutoff[1]
            )
        else:
            inp_data += "\n{} {} {}".format(vdw_style, cut_style, cutoff)
    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_charge_style(
    charge_styles, cutoffs, ewald_accuracy=None, dsf_damping=None
):
    """Get the Charge_Style section of the input file

    Parameters
    ----------
    charge_styles : list
        list of charge styles, one for each box. Valid entries are
        "none", "cut", "ewald" and "dsf"
    cutoffs : list
        list of coulombic cutoffs, one for each box. For a box with
        charge style 'none', the cutoff should be None (it is not written)
    ewald_accuracy : float, optional
        accuracy of ewald sum. Required if any box uses "ewald"
    dsf_damping : float, optional
        value for dsf damping, shared by every "dsf" box. When None the
        damping column is omitted.

    Returns
    -------
    inp_data : str
        the "# Charge_Style" section, one line per box

    Raises
    ------
    ValueError
        if a charge style is unsupported, or "ewald" is requested without
        ``ewald_accuracy``
    """
    assert len(charge_styles) == len(cutoffs)
    valid_charge_styles = ["none", "cut", "ewald", "dsf"]
    for charge_style in charge_styles:
        if charge_style not in valid_charge_styles:
            # Report the valid options, not the user's own input list
            raise ValueError(
                "Unsupported charge_style: {}. Supported options "
                "include {}".format(charge_style, valid_charge_styles)
            )
        if charge_style == "ewald" and ewald_accuracy is None:
            raise ValueError(
                "Ewald selected as the charge style but "
                "no ewald accuracy provided"
            )

    inp_data = """
# Charge_Style"""

    for charge_style, cutoff in zip(charge_styles, cutoffs):
        if charge_style == "none":
            inp_data += "\n{}".format(charge_style)
        elif charge_style == "ewald":
            inp_data += "\ncoul {} {} {}".format(
                charge_style, cutoff, ewald_accuracy
            )
        elif charge_style == "dsf" and dsf_damping is not None:
            inp_data += "\ncoul {} {} {}".format(
                charge_style, cutoff, dsf_damping
            )
        else:
            # "cut", or "dsf" without an explicit damping parameter
            inp_data += "\ncoul {} {}".format(charge_style, cutoff)

    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_mixing_rule(mixing_rule, custom_mixing_dict=None):
    """Get the Mixing_Rule section of the input file

    Parameters
    ----------
    mixing_rule : str
        one of "lb", "geometric" or "custom"
    custom_mixing_dict : dict, optional
        required when ``mixing_rule == "custom"``; each key/value pair is
        written verbatim as one "{key} {value}" line. Ignored for the
        other mixing rules.

    Returns
    -------
    inp_data : str
        the "# Mixing_Rule" section of the input file

    Raises
    ------
    ValueError
        if the mixing rule is unsupported, or "custom" is requested
        without ``custom_mixing_dict``
    """
    valid_mixing_rules = ["lb", "geometric", "custom"]
    if mixing_rule not in valid_mixing_rules:
        raise ValueError(
            "Unsupported mixing rule: {}. Supported options "
            "include {}".format(mixing_rule, valid_mixing_rules)
        )
    if mixing_rule == "custom" and custom_mixing_dict is None:
        raise ValueError(
            "Custom mixing rule requested but no mixing "
            "parameters provided."
        )

    inp_data = """
# Mixing_Rule
{mixing_rule}
""".format(
        mixing_rule=mixing_rule
    )

    if mixing_rule == "custom":
        for pair, parms in custom_mixing_dict.items():
            inp_data += "{pair} {parms}\n".format(pair=pair, parms=parms)

    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_seed_info(seed1=None, seed2=None):
    """Get the Seed_Info section of the input file

    Parameters
    ----------
    seed1, seed2 : int, optional
        random number seeds, each an integer in [1, 100000000]. Any seed
        left as None is drawn with ``np.random.randint(1, 100000000)``.

    Returns
    -------
    inp_data : str
        the "# Seed_Info" section of the input file

    Raises
    ------
    ValueError
        if a seed is not an integer (bools are rejected) or lies outside
        [1, 100000000]
    """
    max_seed = 100000000
    if seed1 is None:
        seed1 = np.random.randint(1, max_seed)
    if seed2 is None:
        seed2 = np.random.randint(1, max_seed)

    for seed in (seed1, seed2):
        # The lower bound is 1 (not 0) to match the error message and the
        # range used for randomly drawn seeds; non-integers would be
        # written verbatim (e.g. "1.5") into the input file.
        if (
            isinstance(seed, bool)
            or not isinstance(seed, (int, np.integer))
            or not 1 <= seed <= max_seed
        ):
            raise ValueError(
                "Seeds must be integers between " "1 and 100000000"
            )

    inp_data = """
# Seed_Info
{seed1} {seed2}""".format(
        seed1=seed1, seed2=seed2
    )

    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_minimum_cutoff(cutoff):
    """Return the ``# Rcutoff_Low`` section of the input file.

    Raises
    ------
    TypeError
        if ``cutoff`` is not a float or int
    """
    if isinstance(cutoff, (float, int)):
        return (
            "\n# Rcutoff_Low\n{}\n".format(cutoff)
            + "!------------------------------------------------------------------------------\n"
        )
    raise TypeError("rcut_min should be of type float")


def get_pair_energy(save):
    """Return the ``# Pair_Energy`` section, written as "true"/"false".

    Raises
    ------
    TypeError
        if ``save`` is not a ``bool``
    """
    if not isinstance(save, bool):
        raise TypeError("pair_energy must be of type boolean")

    flag = str(save).lower()
    return (
        "\n# Pair_Energy\n"
        + flag
        + "\n!------------------------------------------------------------------------------\n"
    )


def get_molecule_files(max_molecules_dict):
    """Return the ``# Molecule_Files`` section.

    Each ``filename: max_mols`` entry of ``max_molecules_dict`` becomes one
    "filename max_mols" line, in the dict's iteration order.
    """
    entries = "".join(
        "\n{} {}".format(mcf_name, limit)
        for mcf_name, limit in max_molecules_dict.items()
    )
    return (
        "\n# Molecule_Files"
        + entries
        + "\n!------------------------------------------------------------------------------\n"
    )


def get_box_info(boxes, restricted_type, restricted_value):
    """Get the box info section of the input file

    Parameters
    ----------
    boxes : list
       list of 3x3 box matrices (unyt arrays with length units), one
       box matrix per simulation box. Dimensions are written in
       angstrom; row i of a "cell_matrix" box is written as column i.
       The caller's arrays are not modified.
    restricted_type : list
       list of restricted insertion types per species
       per simulation box
    restricted_value : list
       list of restricted insertion values corresponding
       to `restricted_type`

    Returns
    -------
    inp_data : str
        the "# Box_Info" section: the number of boxes, then for each
        box its type ("cubic", "orthogonal" or "cell_matrix"), its
        dimensions and any restricted_insertion lines
    """
    # unyt array doesn't seem to support 3D arrays right now
    # so the shape has to be checked in a more roundabout way.
    # Work on angstrom copies rather than converting the caller's
    # arrays in place.
    boxes_ang = []
    for box in boxes:
        assert box.shape == (3, 3)
        boxes_ang.append(box.to("angstrom"))

    inp_data = """
# Box_Info
{nbr_boxes}""".format(
        nbr_boxes=len(boxes_ang)
    )

    apply_restrictions = bool(restricted_type and restricted_value)

    for ibox, box in enumerate(boxes_ang):
        box_type = _get_box_type(box)
        inp_data += "\n" + box_type + _format_box_dimensions(box, box_type)

        # Boxes beyond the end of the restriction lists get no
        # restricted insertions, rather than being dropped from the
        # output (which would contradict the box count written above).
        if (
            not apply_restrictions
            or ibox >= len(restricted_type)
            or ibox >= len(restricted_value)
        ):
            continue

        for typ, value in zip(restricted_type[ibox], restricted_value[ibox]):
            _check_restricted_insertions(box, typ, value)
            if typ == "interface":
                inp_data += "restricted_insertion {} {} {}\n".format(
                    typ,
                    value[0].to_value("angstrom"),
                    value[1].to_value("angstrom"),
                )
            elif typ:
                inp_data += "restricted_insertion {} {}\n".format(
                    typ, value.to_value("angstrom")
                )

    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def _get_box_type(box):
    """Classify a 3x3 box matrix as "cubic", "orthogonal" or "cell_matrix"."""
    values = box.to_value()
    off_diagonal = values - np.diag(np.diagonal(values))
    if np.count_nonzero(off_diagonal) != 0:
        return "cell_matrix"
    if np.all(np.diagonal(values) == values[0, 0]):
        return "cubic"
    return "orthogonal"


def _format_box_dimensions(box, box_type):
    """Return the dimension line(s) for one box, wrapped in newlines."""
    if box_type == "cubic":
        return "\n{}\n".format(box[0][0].to_value())
    if box_type == "orthogonal":
        return "\n{} {} {}\n".format(
            box[0][0].to_value(),
            box[1][1].to_value(),
            box[2][2].to_value(),
        )
    # cell_matrix: row i of the box matrix is written as column i
    return "\n{} {} {}\n{} {} {}\n{} {} {}\n".format(
        box[0][0].to_value(),
        box[1][0].to_value(),
        box[2][0].to_value(),
        box[0][1].to_value(),
        box[1][1].to_value(),
        box[2][1].to_value(),
        box[0][2].to_value(),
        box[1][2].to_value(),
        box[2][2].to_value(),
    )


def get_temperature_info(temps):
    """Return the ``# Temperature_Info`` section, one line per box.

    Parameters
    ----------
    temps : list
        temperatures (float or int), one for each box

    Raises
    ------
    TypeError
        if an entry is not a float or int
    ValueError
        if an entry is negative
    """
    for value in temps:
        if not isinstance(value, (float, int)):
            raise TypeError("Temperature must be of type float")
        if value < 0.0:
            raise ValueError(
                "Specified temperature ({}) is less than zero".format(value)
            )

    rows = "".join("\n{}".format(value) for value in temps)
    return (
        "\n# Temperature_Info"
        + rows
        + "\n!------------------------------------------------------------------------------\n"
    )


def get_pressure_info(pressures):
    """Return the ``# Pressure_Info`` section, one line per box.

    Each pressure is written with ``to_value()`` in whatever units it
    already carries; no unit conversion happens here.

    Parameters
    ----------
    pressures : list
        ``unyt_array`` pressures, one for each box

    Raises
    ------
    TypeError
        if any entry is not a ``unyt_array``
    """
    if any(not isinstance(p, u.unyt_array) for p in pressures):
        raise TypeError("Pressure must be of type `unyt_array`")

    rows = "".join("\n{}".format(p.to_value()) for p in pressures)
    return (
        "\n# Pressure_Info"
        + rows
        + "\n!------------------------------------------------------------------------------\n"
    )


def get_chemical_potential_info(chem_pots):
    """Get the Chemical_Potential_Info section of the input file

    Parameters
    ----------
    chem_pots : list
         list of chemical potentials with one for each species. Each entry
         is a ``unyt_array``/``unyt_quantity`` or, for non-insertable
         species, the string ``"none"`` (Python ``None`` is rejected).

    Returns
    -------
    inp_data : str
        The section header followed by all values on a single line (each
        value followed by one space) and the trailing separator line.

    Raises
    ------
    TypeError
        If an entry is neither ``"none"`` nor a ``unyt_array``.
    """

    for chem_pot in chem_pots:
        if chem_pot != "none" and not isinstance(chem_pot, u.unyt_array):
            raise TypeError(
                'Chemical potentials must be "none" or '
                "of type `unyt_array`"
            )

    # "none" is written verbatim; unyt values are written as bare numbers
    # in whatever units they already carry.
    values = "".join(
        "{} ".format(chem_pot if chem_pot == "none" else chem_pot.to_value())
        for chem_pot in chem_pots
    )

    inp_data = """
# Chemical_Potential_Info
"""
    inp_data += values
    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_move_probability_info(**kwargs):
    """Get the Move_Probability_Info section of the input file

    Parameters
    ----------
    kwargs : dict
         Dictionary of move probability information. Each valid keyword and
         associated information is described below

         'translate' : [prob, box_i, box_j, ...]
                    where prob is the overall probability of
                    selecting a translation move and box_i/j are lists
                    containing the max displacement (angstroms) for
                    each species

         'rotate' : [prob, box_i, box_j, ...]
                    where prob is the overall probability of
                    selecting a rotation move and box_i/j are lists
                    containing the max rotations (degrees) for
                    each species

         'angle' : prob
                    where prob is the overall probability of selecting
                    an angle move

         'dihed' : [prob, displacements]
                    where prob is the overall probability of selecting
                    a dihedral move and displacements is a list of the
                    maximum displacement (degrees) for dihedrals in
                    each species

         'regrow' : [prob, species_probs]
                    where prob is the overall probability of selecting
                    a regrowth move and species_probs is a list of the
                    probabilities of selecting a regrowth move for each
                    species

         'volume' : [prob, displacements]
                    where prob is the overall probability of selecting
                    a volume move and displacements is a list of the
                    max volume change for each box

         'insert' : [prob, insertable]
                    where prob is the overall probability of selecting
                    a insertion/deletion move and insertable is a list
                    of booleans indicating whether each species is
                    insertable or not

         'swap'  : [prob, insertable, prob_species, prob_from_box]
                    where prob is the overall probability of selecting
                    a swap move and insertable is a list of booleans
                    indicating whether each species is insertable or not.
                    prob_species and prob_from_box are optional and
                    should be None if they are not to be specified.
                    prob_species is a list with one value for each
                    species that determines the probability of selecting
                    that species for the swap move. prob_from_box is a
                    list with one value for each box specifying the
                    probability of using the box as a donor box.

         'restricted_insertion' : [restricted_type, restricted_value]
                                  where restricted_type is the type of
                                  restricted insertion and restricted_value
                                  is the value that corresponds to the
                                  restricted_type.  Both variables are lists
                                  that correspond to each species per box.
                                  If None, then restricted type is not specified
                                  for that particular species in the given box.
                                  Only consulted when 'insert' or 'swap' is
                                  given.

    Returns
    -------
    inp_data : str
        The Move_Probability_Info block: one sub-section per supplied move
        (always in the order translate, rotate, angle, dihed, regrow,
        volume, insert, swap) terminated by ``# Done_Probability_Info``.

    Raises
    ------
    ValueError
        If an unrecognised keyword is supplied.
    TypeError
        If any move entry is not formatted as described above.
    """

    # First a sanity check on kwargs
    valid_args = [
        "translate",
        "rotate",
        "angle",
        "dihed",
        "regrow",
        "volume",
        "insert",
        "swap",
        "restricted_insertion",
    ]

    for arg in kwargs:
        if arg not in valid_args:
            raise ValueError(
                "Invalid probability info section {}. "
                "Allowable options include {}".format(arg, valid_args)
            )

    inp_data = "\n# Move_Probability_Info\n"

    # Translation, rotation and dihedral moves share the same
    # [prob, per-box list, per-box list, ...] layout.
    if "translate" in kwargs:
        inp_data += _move_prob_per_box_section(
            kwargs["translate"], "Translate", "translation", "Prob_Translation"
        )

    if "rotate" in kwargs:
        inp_data += _move_prob_per_box_section(
            kwargs["rotate"], "Rotation", "rotation", "Prob_Rotation"
        )

    if "angle" in kwargs:
        angle = kwargs["angle"]
        if not isinstance(angle, (float, int)):
            raise TypeError(
                "Angle probability information not formatted properly"
            )
        inp_data += "\n# Prob_Angle\n{}".format(angle) + _MOVE_PROB_SEPARATOR

    if "dihed" in kwargs:
        inp_data += _move_prob_per_box_section(
            kwargs["dihed"], "Dihedral", "dihedral", "Prob_Dihedral"
        )

    if "regrow" in kwargs:
        inp_data += _move_prob_regrow_section(kwargs["regrow"])

    if "volume" in kwargs:
        inp_data += _move_prob_volume_section(kwargs["volume"])

    if "insert" in kwargs:
        inp_data += _move_prob_insert_section(kwargs["insert"], kwargs)

    if "swap" in kwargs:
        inp_data += _move_prob_swap_section(kwargs["swap"], kwargs)

    inp_data += "\n# Done_Probability_Info" + _MOVE_PROB_SEPARATOR

    return inp_data


# Separator line (with surrounding newlines) closing each sub-section.
_MOVE_PROB_SEPARATOR = """
!------------------------------------------------------------------------------
"""


def _move_prob_per_box_section(info, label, move, header):
    """Validate and format a ``[prob, box_list, box_list, ...]`` move entry."""
    bad_format = "{} probability information not formatted properly".format(
        label
    )
    if not isinstance(info, list):
        raise TypeError(bad_format)
    if not isinstance(info[0], (float, int)):
        raise TypeError(
            "Probability of {} move must be "
            "a floating point value".format(move)
        )
    for box_values in info[1:]:
        if not isinstance(box_values, list):
            raise TypeError(bad_format)
        for value in box_values:
            if not isinstance(value, (float, int)):
                raise TypeError(bad_format)

    section = "\n# {}\n{}".format(header, info[0])
    for box_values in info[1:]:
        # One line per box; every value is followed by a single space.
        section += "\n" + "".join("{} ".format(v) for v in box_values)
    return section + _MOVE_PROB_SEPARATOR


def _move_prob_regrow_section(regrow):
    """Validate and format the ``[prob, species_probs]`` regrowth entry."""
    bad_format = "Regrowth probability information not formatted properly"
    if not isinstance(regrow, list):
        raise TypeError(bad_format)
    if len(regrow) != 2:
        raise TypeError(bad_format)
    if not isinstance(regrow[0], (float, int)):
        raise TypeError(
            "Probability of regrowth move must be a floating point value"
        )
    if not isinstance(regrow[1], list):
        raise TypeError(bad_format)
    for sp_prob in regrow[1]:
        if not isinstance(sp_prob, (float, int)):
            raise TypeError(
                "Probability of selecting each "
                "species for a regrowth move must be a "
                "floating point value"
            )

    section = "\n# Prob_Regrowth\n{}\n".format(regrow[0])
    section += "".join("{} ".format(p) for p in regrow[1])
    return section + _MOVE_PROB_SEPARATOR


def _move_prob_volume_section(volume):
    """Validate and format the ``[prob, displacements]`` volume entry."""
    bad_format = "Volume probability information not formatted properly"
    if not isinstance(volume, list):
        raise TypeError(bad_format)
    if len(volume) != 2:
        raise TypeError(bad_format)
    if not isinstance(volume[0], (float, int)):
        raise TypeError(
            "Probability of volume move must be a floating point value"
        )
    if not isinstance(volume[1], list):
        raise TypeError(bad_format)
    for max_displace in volume[1]:
        if not isinstance(max_displace, (float, int)):
            raise TypeError(
                "Max displacement for volume move "
                "must be a floating point value"
            )

    section = "\n# Prob_Volume\n{}".format(volume[0])
    # One max displacement per line (one per box).
    section += "".join("\n{}".format(d) for d in volume[1])
    return section + _MOVE_PROB_SEPARATOR


def _move_prob_check_insertable(flags):
    """Raise TypeError unless every insertable flag is a real bool."""
    for insertable in flags:
        if not isinstance(insertable, bool):
            raise TypeError(
                "Whether or not a species is insertable "
                "must be a boolean value"
            )


def _move_prob_insertion_methods(insertable_flags, kwargs):
    """Return per-species insertion keywords ("cbmc"/"restricted"/"none ")."""
    if "restricted_insertion" in kwargs:
        # restricted_insertion is [restricted_type, restricted_value]; [0][0]
        # selects the first box's per-species restriction types. zip stops
        # at the shorter of the two lists.
        restricted_types = kwargs["restricted_insertion"][0][0]
        pairs = zip(insertable_flags, restricted_types)
    else:
        pairs = ((insertable, None) for insertable in insertable_flags)

    methods = ""
    for insertable, restricted in pairs:
        if not insertable:
            methods += "none "
        elif restricted:
            methods += "restricted "
        else:
            methods += "cbmc "
    return methods


def _move_prob_insert_section(insert, kwargs):
    """Validate ``[prob, insertable]``; format Prob_Insertion and Prob_Deletion."""
    bad_format = "Insertion probability information not formatted properly"
    if not isinstance(insert, list):
        raise TypeError(bad_format)
    if len(insert) != 2:
        raise TypeError(bad_format)
    if not isinstance(insert[0], (float, int)):
        raise TypeError(
            "Probability of insertion move must be a floating point value"
        )
    if not isinstance(insert[1], list):
        raise TypeError(bad_format)
    _move_prob_check_insertable(insert[1])

    # Deletions are written with the same overall probability as insertions.
    return (
        "\n# Prob_Insertion\n{}\n".format(insert[0])
        + _move_prob_insertion_methods(insert[1], kwargs)
        + _MOVE_PROB_SEPARATOR
        + "\n# Prob_Deletion\n{}\n".format(insert[0])
        + _MOVE_PROB_SEPARATOR
    )


def _move_prob_swap_section(swap, kwargs):
    """Validate and format ``[prob, insertable, prob_species, prob_from_box]``."""
    bad_format = "Swap probability information not formatted properly"
    if not isinstance(swap, list):
        raise TypeError(bad_format)
    if len(swap) != 4:
        raise TypeError(bad_format)
    if not isinstance(swap[0], (float, int)):
        raise TypeError(
            "Probability of swap move must be a floating point value"
        )
    if not isinstance(swap[1], list):
        raise TypeError(
            "Insertion probability information not formatted properly"
        )
    _move_prob_check_insertable(swap[1])
    if swap[2] is not None:
        if not isinstance(swap[2], list):
            raise TypeError(bad_format)
        for prob in swap[2]:
            if not isinstance(prob, (float, int)):
                raise TypeError(
                    "Probability of selecting species "
                    "for a swap move must be a floating point "
                    "value"
                )
    if swap[3] is not None:
        if not isinstance(swap[3], list):
            raise TypeError(bad_format)
        for prob in swap[3]:
            if not isinstance(prob, (float, int)):
                raise TypeError(
                    "Probability of selecting box "
                    "as donor for a swap move must be a "
                    "floating point value"
                )

    section = "\n# Prob_Swap\n{}\n".format(swap[0])
    section += _move_prob_insertion_methods(swap[1], kwargs)
    # Optional per-species / per-donor-box probabilities.
    if swap[2] is not None:
        section += "\nprob_swap_species " + "".join(
            "{} ".format(p) for p in swap[2]
        )
    if swap[3] is not None:
        section += "\nprob_swap_from_box " + "".join(
            "{} ".format(p) for p in swap[3]
        )
    return section + _MOVE_PROB_SEPARATOR


def get_start_type(start_types):
    """Build the ``# Start_Type`` section of the input file.

    Parameters
    ----------
    start_types : list
        One start-type entry per box; each is written verbatim on its own
        line.

    Returns
    -------
    inp_data : str
        The section header, one line per box, and the trailing separator.
    """
    pieces = ["", "# Start_Type"]
    pieces += ["{}".format(entry) for entry in start_types]
    pieces += [
        "!------------------------------------------------------------------------------",
        "",
    ]
    return "\n".join(pieces)


def get_run_type(run_type, thermal_freq, vol_freq=None):
    """Get the Run_Type section of the input file

    Parameters
    ----------
    run_type : string
         'equilibration' or 'production'
    thermal_freq : int
         frequency of updating thermal move displacement
         widths/output statistics
    vol_freq : int, optional
         frequency of updating volume move displacement
         widths/output statistics; appended to the same line when given

    Returns
    -------
    inp_data : str
        The ``# Run_Type`` section including the trailing separator line.

    Raises
    ------
    ValueError
        If ``run_type`` is not recognised, or a frequency is not an int.
    """

    valid_run_types = ["equilibration", "production"]
    if run_type not in valid_run_types:
        raise ValueError(
            "Invalid run type specified {} "
            "Allowable options include {}".format(run_type, valid_run_types)
        )
    if not isinstance(thermal_freq, int):
        raise ValueError("thermal_freq must be an integer")
    if vol_freq is not None and not isinstance(vol_freq, int):
        raise ValueError("vol_freq must be an integer")

    inp_data = """
# Run_Type
{run_type} {thermal_freq} """.format(
        run_type=run_type, thermal_freq=thermal_freq
    )

    if vol_freq is not None:
        inp_data += "{vol_freq}".format(vol_freq=vol_freq)

    inp_data += """
!------------------------------------------------------------------------------
"""

    return inp_data


def get_simulation_length_info(
    units,
    prop_freq,
    coord_freq,
    run_length,
    steps_per_sweep=None,
    block_avg_freq=None,
):
    """Build the ``# Simulation_Length_Info`` section of the input file.

    Parameters
    ----------
    units : string
         'minutes', 'steps', or 'sweeps'
    prop_freq : int
         frequency of writing property info
    coord_freq : int
         frequency of writing coordinates to file
    run_length : int
         number of (units) to run the simulation
    steps_per_sweep : int, optional
         number of steps in a MC sweep; written only when given
    block_avg_freq : int, optional
         write properties as block averages, averaged over
         block_avg_freq (units); written only when given

    Returns
    -------
    inp_data : str
        The section text, ending with the separator line.

    Raises
    ------
    ValueError
        If ``units`` is not recognised or any frequency/length is not an int.
    """
    valid_units = ["minutes", "steps", "sweeps"]
    if units not in valid_units:
        raise ValueError(
            "Invalid units specified {} Allowable options "
            "include {}".format(units, valid_units)
        )

    required_ints = (
        ("prop_freq", prop_freq),
        ("coord_freq", coord_freq),
        ("run_length", run_length),
    )
    for arg_name, arg_value in required_ints:
        if not isinstance(arg_value, int):
            raise ValueError("{} must be an integer".format(arg_name))

    optional_ints = (
        ("steps_per_sweep", steps_per_sweep),
        ("block_avg_freq", block_avg_freq),
    )
    for arg_name, arg_value in optional_ints:
        if arg_value is not None and not isinstance(arg_value, int):
            raise ValueError("{} must be an integer".format(arg_name))

    text = (
        "\n# Simulation_Length_Info"
        "\nunits {}"
        "\nprop_freq {}"
        "\ncoord_freq {}"
        "\nrun {}"
    ).format(units, prop_freq, coord_freq, run_length)

    # Optional keywords are each followed by an extra newline.
    if steps_per_sweep is not None:
        text += "\nsteps_per_sweep {}\n".format(steps_per_sweep)
    if block_avg_freq is not None:
        text += "\nblock_averages {}\n".format(block_avg_freq)

    text += "\n!------------------------------------------------------------------------------\n"
    return text


def get_property_info(properties, nbr_boxes):
    """Build the ``# Property_Info`` sections of the input file.

    One ``# Property_Info <n>`` section is emitted per box (numbered from 1),
    each listing the same properties one per line. No separator line is
    appended.

    Parameters
    ----------
    properties : list
         desired properties to output
    nbr_boxes : int
         number of boxes to write a section for

    Returns
    -------
    inp_data : str
        Concatenated Property_Info sections (empty when ``nbr_boxes`` is 0).

    Raises
    ------
    TypeError
        If ``properties`` is not a list.
    ValueError
        If any property is not recognised.
    """
    if not isinstance(properties, list):
        raise TypeError(
            "Properties should be specified "
            "as a list of properties to be printed"
        )

    valid_properties = [
        "energy_total",
        "energy_intra",
        "energy_bond",
        "energy_angle",
        "energy_dihedral",
        "energy_improper",
        "energy_intravdw",
        "energy_intraq",
        "energy_inter",
        "energy_intervdw",
        "energy_lrc",
        "energy_interq",
        "energy_recip",
        "energy_self",
        "enthalpy",
        "pressure",
        "pressure_xx",
        "pressure_yy",
        "pressure_zz",
        "volume",
        "nmols",
        "density",
        "mass_density",
    ]

    for prop in properties:
        if prop in valid_properties:
            continue
        raise ValueError(
            "Invalid property: {}. Valid choices "
            "include {}".format(prop, valid_properties)
        )

    listing = "".join("{}\n".format(prop) for prop in properties)
    return "".join(
        "\n# Property_Info {}\n{}".format(box_number, listing)
        for box_number in range(1, nbr_boxes + 1)
    )


def get_fragment_files(files=None):
    """Build the ``# Fragment_Files`` section of the input file.

    Parameters
    ----------
    files : list of str, optional
        Pre-formatted fragment-file lines. They are concatenated verbatim,
        so each entry must carry its own trailing newline.

    Returns
    -------
    inp_data : str
        The header, the given lines, and the separator line.
    """
    text = "\n# Fragment_Files\n"
    entries = [] if files is None else files
    for entry in entries:
        text += entry
    text += "!------------------------------------------------------------------------------\n"
    return text


def get_verbose_log(verbose):
    """Build the ``# Verbose_Logfile`` section of the input file.

    Parameters
    ----------
    verbose : bool
        Whether to request a verbose log; written as ``true``/``false``.

    Returns
    -------
    inp_data : str
        The section including the trailing separator line.

    Raises
    ------
    TypeError
        If ``verbose`` is not a bool.
    """
    if not isinstance(verbose, bool):
        raise TypeError("Verbosity must be a boolean")

    flag = "true" if verbose else "false"
    return (
        "\n# Verbose_Logfile\n"
        + flag
        + "\n!------------------------------------------------------------------------------\n"
    )


def get_cbmc_info(n_insert, n_dihed, cutoffs):
    """Build the ``# CBMC_Info`` section of the input file.

    Parameters
    ----------
    n_insert : int
        number of insertion sites to attempt for CBMC (``kappa_ins``)
    n_dihed : int
        number of dihedral angles to attempt for CBMC (``kappa_dih``)
    cutoffs : list
        CBMC cutoff value for each box, written on the ``rcut_cbmc`` line

    Returns
    -------
    inp_data : str
        The section including the trailing separator line.

    Raises
    ------
    TypeError
        If the counts are not ints, ``cutoffs`` is not a list, or any cutoff
        is not a ``float``/``int``.
    """
    if not isinstance(n_insert, int):
        raise TypeError("Number of CBMC insertion attempts must be an integer")
    if not isinstance(n_dihed, int):
        raise TypeError(
            "Number of CBMC dihedral angle attempts must be an integer"
        )
    if not isinstance(cutoffs, list):
        raise TypeError("Cutoff information improperly specified")
    if any(not isinstance(cutoff, (float, int)) for cutoff in cutoffs):
        raise TypeError("CBMC cutoff must be a float")

    # Each cutoff is preceded by a single space on the rcut_cbmc line.
    cutoff_text = "".join(" {}".format(cutoff) for cutoff in cutoffs)
    return (
        "\n# CBMC_Info"
        "\nkappa_ins {}"
        "\nkappa_dih {}"
        "\nrcut_cbmc{}"
        "\n!------------------------------------------------------------------------------\n"
    ).format(n_insert, n_dihed, cutoff_text)





def _get_possible_kwargs(desc=False):
    """Return the keyword arguments accepted by ``generate_input``.

    Parameters
    ----------
    desc : bool, optional
        If True, return a dict mapping each keyword name to a short
        description of its expected format. Otherwise (default) return the
        list of keyword names.
    """
    valid_kwargs = {
        "run_name": "str, name of output",
        "verbose_log": "boolean, write verbose log file",
        "vdw_style": 'str, "lj" or "none"',
        "cutoff_style": 'str, "cut" or "cut_tail" or "cut_switch" or "cut_shift"',
        "vdw_cutoff": 'unyt_array or unyt_quantity with `length` units, except for "cut_switch", where [inner_cutoff, outer_cutoff].',
        "vdw_cutoff_box1": 'customize vdw cutoff for box 1. see "vdw_cutoff" for format',
        "vdw_cutoff_box2": 'customize vdw cutoff for box 2. see "vdw_cutoff" for format',
        "charge_style": 'str, "none" or "cut" or "ewald" or "dsf"',
        "charge_cutoff": "unyt_array or unyt_quantity with `length` units",
        "charge_cutoff_box1": 'customize charge cutoff for box 1. see "charge_cutoff" for format',
        "charge_cutoff_box2": 'customize charge cutoff for box 2. see "charge_cutoff" for format',
        "ewald_accuracy": "float, accuracy of ewald sum",
        "dsf_damping": "float, damping parameter for dsf charge method",
        "mixing_rule": 'str, "lb" or "geometric" or "custom"',
        "custom_mixing_dict": "dict, one key-value pair per atomtype-pair, key=str of species comb, value=str of params",
        "seeds": "list of ints, [seed1,seed2], where each seed is an integer",
        "rcut_min": "unyt_array or unyt_quantity with `length` units, automatically reject move if atoms are closer than this distance",
        "pair_energy": "boolean, store pair energies (faster but requires more memory)",
        "max_molecules": "list of ints, maximum number of molecules for each species",
        "pressure": "unyt_array or unyt_quantity with `pressure` units, desired pressure (npt and gemc-npt)",
        "pressure_box1": 'customize pressure for box 1. see "pressure" for format',
        "pressure_box2": 'customize pressure for box 2. see "pressure" for format',
        "chemical_potentials": "list of unyt_array or unyt_quantity with units of `energy/mol`, specify the desired chemical potential for each species",
        "thermal_stat_freq": "int, frequency of printing/updating non-volume moves",
        "vol_stat_freq": "int, frequency of printing/updating volume moves",
        "units": 'str, units for run/thermo/coord run_length/freqs. "minutes" or "steps" or "sweeps"',
        "prop_freq": "int, frequency of writing thermo properties",
        "coord_freq": "int, frequency of writing coordinates",
        "steps_per_sweep": "int, number of MC steps defined as a single sweep",
        "block_avg_freq": "int, block average size",
        # Keep this list in sync with ``valid_properties`` in
        # get_property_info, which rejects anything else.
        "properties": (
            "list, properties to write to thermo file. Valid options include "
            '"energy_total", "energy_intra", "energy_bond", "energy_angle", '
            '"energy_dihedral", "energy_improper", "energy_intravdw", '
            '"energy_intraq", "energy_inter", "energy_intervdw", '
            '"energy_lrc", "energy_interq", "energy_recip", "energy_self", '
            '"enthalpy", "pressure", "pressure_xx", "pressure_yy", '
            '"pressure_zz", "volume", "nmols", "density", "mass_density"'
        ),
        "angle_style": "list of str, angle style for each species",
    }
    if desc:
        return valid_kwargs
    else:
        return list(valid_kwargs.keys())


def _check_restricted_insertions(box, restriction_type, restriction_value):
    """Check that restricted insertion values are valid given the box size

    Only the diagonal entries ``box[i][i]`` are used as the box lengths, so
    the check is only meaningful for cubic/orthorhombic boxes.

    Parameters
    ----------
    box : sequence
        3x3 box matrix; ``box[i][i]`` is taken as the length along axis i.
    restriction_type : str or None
        "cylinder", "sphere", "slitpore" or "interface"; any other value is
        not checked.
    restriction_value : object with ``to_value()``, or a two-element sequence
        For "cylinder"/"sphere" the 'r_max' value and for "slitpore" the
        'z_max' value; both are doubled before comparison. For "interface",
        element 1 is compared directly against the z length and the
        difference element 1 - element 0 must also fit.

    Raises
    ------
    ValueError
        If the restriction does not fit inside the box.
    """
    lengths = np.array([box[axis][axis] for axis in range(3)])
    z_length = lengths[2]

    if restriction_type in ["cylinder", "sphere"]:
        if np.any(restriction_value.to_value() * 2 > lengths):
            raise ValueError(
                "Restricted insertion 'r_max' value is"
                " greater than the box coordinates."
            )
        return

    if restriction_type == "slitpore":
        if restriction_value.to_value() * 2 > z_length:
            raise ValueError(
                "Restricted insertion 'z_max' value is"
                " greater than the z-coordinate of the box."
            )
        return

    if restriction_type != "interface":
        return

    upper = restriction_value[1].to_value()
    span = upper - restriction_value[0].to_value()
    if upper > z_length:
        raise ValueError(
            "Restricted insertion 'z_max' passed"
            " for 'interface' is"
            " greater than the z-coordinate of the box."
        )
    if span > z_length:
        raise ValueError(
            "Restricted insertion value passed"
            " for 'interface' is"
            " greater than the z-coordinate of the box."
        )


def _check_kwarg_units(kwargs):
    """Validate the dimensions of every unit-bearing keyword argument.

    Each known argument present in ``kwargs`` is checked through
    ``_check_kwarg_units_helper``; list-valued vdW cutoffs may be replaced in
    ``kwargs`` by that helper. Chemical potentials are handled separately
    because string entries are allowed and skipped.
    """
    # vdW cutoffs may be a single value or an [inner, outer] pair.
    for name in ("vdw_cutoff", "vdw_cutoff_box1", "vdw_cutoff_box2"):
        _check_kwarg_units_helper(
            kwargs, name, dimensions.length, list_length=2
        )

    single_valued = (
        ("charge_cutoff", dimensions.length),
        ("charge_cutoff_box1", dimensions.length),
        ("charge_cutoff_box2", dimensions.length),
        ("rcut_min", dimensions.length),
        ("pressure", dimensions.pressure),
        ("pressure_box1", dimensions.pressure),
        ("pressure_box2", dimensions.pressure),
    )
    for name, dimension in single_valued:
        _check_kwarg_units_helper(kwargs, name, dimension)

    # Handle chemical potentials here because quirky: strings are skipped.
    for mu in kwargs.get("chemical_potentials", []):
        if not isinstance(mu, str):
            validate_unit(mu, dimensions.energy)


def _check_kwarg_units_helper(kwargs, kwarg_name, dimension, list_length=0):
    """Validate the units of one keyword argument, if it is present.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments given to ``generate_input``. For list-valued
        arguments the entry is replaced in place by the value returned from
        ``validate_unit_list``.
    kwarg_name : str
        Key to check; nothing happens if it is absent.
    dimension : unyt dimension
        Required dimension of the value(s).
    list_length : int, optional
        0 (default) means the argument must hold exactly one value.
        Otherwise the argument may be a single value or a collection of
        ``list_length`` values.

    Raises
    ------
    TypeError
        If a single-valued argument is a ``unyt_array`` with more than one
        element. ``validate_unit``/``validate_unit_list`` may raise on their
        own for bad units.
    """
    if kwarg_name not in kwargs:
        return
    value = kwargs[kwarg_name]
    if list_length == 0:
        validate_unit(value, dimension)
        # Make sure length is *actually* 1. isinstance also covers
        # unyt_quantity (a unyt_array subclass, always size 1).
        if isinstance(value, u.unyt_array) and value.size > 1:
            raise TypeError(f"Invalid format for argument {kwarg_name}")
    elif isinstance(value, u.unyt_quantity) or len(value) == 1:
        # Single item array/quantity
        validate_unit(value, dimension)
    else:
        kwargs[kwarg_name] = validate_unit_list(
            value,
            (list_length,),
            dimension,
            kwarg_name,
        )


def _convert_kwarg_units(kwargs):
    """Convert unit-bearing kwargs, in place, to fixed target units.

    Cutoff lengths and ``rcut_min`` become angstrom, pressures become bar,
    and every non-string entry of ``chemical_potentials`` becomes kJ/mol.
    Keys missing from ``kwargs`` are left alone.
    """
    target_units = (
        ("vdw_cutoff", "angstrom"),
        ("vdw_cutoff_box1", "angstrom"),
        ("vdw_cutoff_box2", "angstrom"),
        ("charge_cutoff", "angstrom"),
        ("charge_cutoff_box1", "angstrom"),
        ("charge_cutoff_box2", "angstrom"),
        ("rcut_min", "angstrom"),
        ("pressure", "bar"),
        ("pressure_box1", "bar"),
        ("pressure_box2", "bar"),
    )
    for name, unit in target_units:
        _convert_kwarg_units_helper(kwargs, name, unit)

    # Chemical potentials may mix strings with unyt quantities: strings are
    # passed through untouched, quantities are converted.
    if "chemical_potentials" in kwargs:
        kwargs["chemical_potentials"] = [
            mu if isinstance(mu, str) else mu.to("kJ/mol")
            for mu in kwargs["chemical_potentials"]
        ]


def _convert_kwarg_units_helper(kwargs, kwarg_name, unit_name):
    if kwarg_name in kwargs:
        if not isinstance(kwargs[kwarg_name], u.unyt_array):
            raise TypeError(
                f"Something went wrong in converting the units for the "
                f"keyword argument {kwarg_name}. Please check your the format "
                f"of this argument."
            )
        # Convert units (everything should already be a u.unyt_array)
        kwargs[kwarg_name] = kwargs[kwarg_name].to(unit_name)


def _convert_moveset_units(moveset):
    """Convert the unit-bearing attributes of a MoveSet to fixed units.

    The MoveSet is modified in place and also returned.

    * restricted-insertion values, ``max_translate`` and ``cbmc_rcut``
      are converted to angstrom
    * ``max_rotate`` and ``max_dihedral`` are converted to degree
    * ``max_volume`` is converted to angstrom**3

    Parameters
    ----------
    moveset : mosdef_cassandra.MoveSet
        move set whose attributes are converted

    Returns
    -------
    moveset : mosdef_cassandra.MoveSet
        the same object, with converted units

    Notes
    -----
    A falsy ``_restricted_value`` (e.g. None or an empty list) is
    replaced by an empty list.
    """
    # Restricted insertion: one list of entries per box. An entry is either
    # a list of quantities or a single quantity; None entries (presumably
    # "no restriction" -- confirm against MoveSet) are kept as-is.
    new_restricted_value = []
    for box in moveset._restricted_value or []:
        new_boxvals = []
        for val in box:
            # Compare with None explicitly: the truth value of a unyt
            # quantity is False when it is zero (so it would be skipped),
            # and it raises ValueError for multi-element arrays.
            if isinstance(val, list):
                val = [i.to("angstrom") for i in val]
            elif val is not None:
                val = val.to("angstrom")
            new_boxvals.append(val)
        new_restricted_value.append(new_boxvals)
    moveset._restricted_value = new_restricted_value

    moveset.max_translate = moveset.max_translate.to("angstrom")
    moveset.max_rotate = moveset.max_rotate.to("degree")
    moveset.max_dihedral = moveset.max_dihedral.to("degree")
    moveset.max_volume = moveset.max_volume.to("angstrom**3")
    moveset.cbmc_rcut = moveset.cbmc_rcut.to("angstrom")

    return moveset