Source code for crowsetta.validation

"""Module with functions for data validation.

Some utilities adapted from scikit-learn under BSD 3 License
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/validation.py
"""
import numbers
from pathlib import PurePath
from typing import Sequence, Union

import numpy as np
import numpy.typing as npt

from .typing import PathLike


def _num_samples(x: npt.ArrayLike) -> int:
    """Return number of samples in array-like x."""
    if not hasattr(x, "__len__") and not hasattr(x, "shape"):
        if hasattr(x, "__array__"):
            x = np.asarray(x)
        else:
            raise TypeError("Expected sequence or array-like, got %s" % type(x))
    if hasattr(x, "shape"):
        if len(x.shape) == 0:
            raise TypeError("Singleton array %r cannot be considered" " a valid collection." % x)
        # Check that shape is returning an integer or default to len
        # Dask dataframes may not return numeric shape[0] value
        if isinstance(x.shape[0], numbers.Integral):
            return x.shape[0]
        else:
            return len(x)
    else:
        return len(x)


[docs] def check_consistent_length(arrays: Sequence[npt.ArrayLike]) -> None: """Check that all arrays have consistent first dimensions. Checks whether all objects in arrays have the same shape or length. Parameters ---------- arrays : list or tuple of input objects. Objects that will be checked for consistent length. """ lengths = [_num_samples(X) for X in arrays if X is not None] uniques = np.unique(lengths) if len(uniques) > 1: raise ValueError( "Found input variables with inconsistent numbers of" " samples: %r" % [int(length) for length in lengths] )
[docs] def column_or_row_or_1d(y: npt.NDArray) -> npt.NDArray: """Ravel column or row vector or 1d numpy array, else raises an error Parameters ---------- y : array-like Returns ------- y : array """ shape = np.shape(y) if (len(shape) == 1) or (len(shape) == 2 and (shape[1] == 1 or shape[0] == 1)): return np.ravel(y) else: raise ValueError("bad input shape {0}".format(shape))
[docs] def validate_ext(file: PathLike, extension: Union[str, tuple]) -> None: """Check that a file has a valid extension. Parameters ---------- file : str, pathlib.Path Path to a file. extension : str, tuple Valid file extension(s). Tuple must be tuple of strings. Function expects that extensions will be specified with a period, e.g. {'.phn', '.PHN'} """ if isinstance(extension, str): extension = (extension,) elif isinstance(extension, tuple): if not all([isinstance(element, str) for element in extension]): raise TypeError( "Must specify all valid extensions as strings, but value was \n" f"'{extension}' with types: {[type(element) for element in extension]}" ) else: raise TypeError(f"Extension must be str or tuple but type was {type(extension)}") if not (isinstance(file, str) or isinstance(file, PurePath)): raise TypeError(f"File must be a str or a pathlib.Path, but type of file was {type(file)}.\n" f"File: {file}") # we need to use `endswith` instead of # e.g. comparing with `pathlib.Path.suffix` # because suffix won't work for "multi-part" extensions like '.not.mat' if not any([str(file).endswith(ext) for ext in extension]): raise ValueError(f"Invalid extension for file: {file}.\n" f"Valid extension(s): '{extension}'")