Source code for pydrobert.torch.argcheck

# Copyright 2023 Sean Robertson

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Boilerplate for checking argument values

There are two broad types of function defined in this submodule: ``is_*`` and ``as_*``.

The ``is_*`` functions check if the passed value satisfies some requirement, usually
being of some type. They are intended to check arguments being passed to objects
(primarily :class:`torch.nn.Module` objects) on initialization. Some accept other types
(e.g. :func:`is_float` accepts :class:`int`, :class:`np.integer`, and
:class:`np.floating` in addition to :class:`float`) which will quietly be cast to the
expected type before returning. Most other ``is_*`` functions check whether the value
satisfies some condition (e.g. being a member of a collection, :func:`is_in`, or being
positive, :func:`is_pos`). Some, e.g. :func:`is_nat`, combine type checks with
conditions (:func:`is_int` and :func:`is_pos`).

The ``as_*`` functions are more agressive, casting their first argument to the type
immediately, then possibly checking a condition. They are intended primarly for use
as the ``type`` argument in :func:`argparser.ArgumentParser.add_argument`.
"""

import os
import string

from typing import (
    Collection,
    Optional,
    TypeVar,
    Callable,
    Any,
    Type,
    Union,
    Generic,
    cast,
)
from typing_extensions import (
    overload,
    Literal,
    get_args,
    ParamSpec,
    Concatenate,
    Protocol,
)
from pathlib import Path

import torch
import numpy as np

__all__ = [
    "as_bool",
    "as_closed01",
    "as_dir",
    "as_file",
    "as_float",
    "as_int",
    "as_nat",
    "as_negf",
    "as_negi",
    "as_nonnegf",
    "as_nonnegi",
    "as_nonposf",
    "as_nonposi",
    "as_open01",
    "as_path_dir",
    "as_path_file",
    "as_path",
    "as_posf",
    "as_posi",
    "as_str",
    "as_tensor",
    "has_ndim",
    "is_a",
    "is_bool",
    "is_btw_closed",
    "is_btw_closedf",
    "is_btw_closedi",
    "is_btw_closedt",
    "is_btw_open",
    "is_btw_openf",
    "is_btw_openi",
    "is_btw_opent",
    "is_btw",
    "is_btwf",
    "is_btwi",
    "is_btwt",
    "is_closed01",
    "is_closed01f",
    "is_closed01i",
    "is_closed01t",
    "is_dir",
    "is_equal",
    "is_equalf",
    "is_equali",
    "is_equalt",
    "is_exactly",
    "is_file",
    "is_float",
    "is_gt",
    "is_gte",
    "is_gtef",
    "is_gtei",
    "is_gtet",
    "is_gtf",
    "is_gti",
    "is_gtt",
    "is_in",
    "is_int",
    "is_lt",
    "is_lte",
    "is_ltef",
    "is_ltei",
    "is_ltet",
    "is_ltf",
    "is_lti",
    "is_ltt",
    "is_nat",
    "is_neg",
    "is_negf",
    "is_negi",
    "is_negt",
    "is_nonempty",
    "is_nonneg",
    "is_nonnegf",
    "is_nonnegi",
    "is_nonnegt",
    "is_nonpos",
    "is_nonposf",
    "is_nonposi",
    "is_nonpost",
    "is_numlike",
    "is_open01",
    "is_open01f",
    "is_open01i",
    "is_open01t",
    "is_path",
    "is_pos",
    "is_posf",
    "is_posi",
    "is_post",
    "is_str",
    "is_tensor",
    "is_token",
]

V1 = TypeVar("V1")
V2 = TypeVar("V2")
NumLike = Union[torch.Tensor, float, int, np.floating, np.integer]
N = TypeVar("N", bound=NumLike)
P = ParamSpec("P")
StrOrPathLike = Union[str, os.PathLike]


class _IsCheck(Protocol[V1]):
    # if allow_none is the literal "False" (also the default), we can narrow to the type
    # of interest. If allow_none is true or is variable, we have to assume it's optional
    @overload
    def __call__(
        self, val: V1, name: Optional[str] = None, allow_none: Literal[False] = False
    ) -> V1:
        ...

    @overload
    def __call__(
        self, val: Optional[V1], name: Optional[str] = None, allow_none: bool = False
    ) -> Optional[V1]:
        ...


def _is_check_allow_none(wrapped: Callable[..., V1]) -> _IsCheck[V1]:
    def wrapper(val, name=None, allow_none=False):
        if allow_none and val is None:
            return val
        return wrapped(val, name)

    return wrapper


def _nv(name: Optional[str], val: Any) -> str:
    if isinstance(val, torch.Tensor):
        if val.numel() == 1:
            return f"{val.item()}" if name is None else f"{name} ({val.item()})"
        else:
            return name if name is not None else "tensor"
    else:
        if isinstance(val, str):
            val = f"'{val}'"
        return f"{val}" if name is None else f"{name} ({val})"


def _type_check_factory(t: Type[V1], *ts: type):
    ts = (t,) + ts

    @_is_check_allow_none
    def _is_check(val, name=None) -> t:
        if isinstance(val, ts):
            return val if (type(val) is t) else t(val)
        else:
            tname = t.__name__
            x = "n" if tname.startswith(("a", "e", "i", "o", "u")) else ""
            raise ValueError(f"{_nv(name, val)} is not a{x} {tname}")

    return _is_check


is_str = _type_check_factory(str)
is_int = _type_check_factory(int, np.integer)
is_bool = _type_check_factory(bool)
is_float = _type_check_factory(float, int, np.integer, np.floating)
is_tensor = _type_check_factory(torch.Tensor)
is_path = _type_check_factory(Path, *get_args(StrOrPathLike))


[docs] @_is_check_allow_none def is_numlike(val, name=None) -> NumLike: if not isinstance(val, get_args(NumLike)): raise ValueError(f"{_nv(name, val)} is not num-like {get_args(N)}") return val
@overload def is_token( val: str, name: Optional[str] = None, empty_okay: bool = False, whitespace: str = string.whitespace, allow_none: Literal[False] = False, ) -> str: ... @overload def is_token( val: Optional[str], name: Optional[str] = None, empty_okay: bool = False, whitespace: str = string.whitespace, allow_none: bool = False, ) -> Optional[str]: ...
[docs] def is_token( val, name=None, empty_okay=False, whitespace=string.whitespace, allow_none=False ): if val is None and allow_none: return val val = is_str(val, name) if not empty_okay and not len(val): raise ValueError(f"{_nv(name, val)} is empty") else: for w in whitespace: if w in val: raise ValueError(f"{_nv(name, val)} contains '{w}'") return val
@overload def is_a( val: V1, t: Type[V1], name: Optional[str], allow_none: Literal[False] = False, ) -> V1: ... @overload def is_a( val: Optional[V1], t: Type[V1], name: Optional[str] = None, allow_none: bool = False, ) -> Optional[V1]: ...
[docs] def is_a(val, t, name=None, allow_none=False): assert not issubclass(t, torch.nn.Module), f"what if {t} is scripted?" if allow_none and val is None: return None if not isinstance(val, t): suf = "n" if t.__name__.startswith(("a", "e", "i", "o", "u")) else "" raise ValueError(f"{_nv(name, val)} is not a{suf} {t.__name__}") return val
@overload def is_in( val: V1, collection: Collection[V1], name: Optional[str] = None, allow_none: Literal[False] = False, ) -> V1: ... @overload def is_in( val: Optional[V1], collection: Collection[V1], name: Optional[str] = None, allow_none: bool = False, ) -> Optional[V1]: ...
[docs] def is_in(val, collection, name=None, allow_none=False): if allow_none and val is None: return None if val not in collection: raise ValueError(f"{_nv(name, val)} is not one of {collection}") return val
[docs] @_is_check_allow_none def is_file(val: StrOrPathLike, name: Optional[str] = None) -> StrOrPathLike: if not os.path.isfile(val): raise ValueError(f"{_nv(name, val)} is not a file") return val
[docs] @_is_check_allow_none def is_dir(val: StrOrPathLike, name: Optional[str] = None) -> StrOrPathLike: if not os.path.isdir(val): raise ValueError(f"{_nv(name, val)} is not a directory") return val
[docs] @_is_check_allow_none def is_pos(val, name=None) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) if (val_ <= 0).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError(f"{_nv(name, val)} is not {x}positive") return val
[docs] @_is_check_allow_none def is_neg(val, name=None) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) if (val_ >= 0).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError(f"{_nv(name, val)} is not {x}negative") return val
[docs] @_is_check_allow_none def is_nonpos(val, name=None) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) if (val_ > 0).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError(f"{_nv(name, val)} is not {x}non-positive") return val
[docs] @_is_check_allow_none def is_nonneg(val: N, name=None) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) if (val_ < 0).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError(f"{_nv(name, val)} is not {x}non-negative") return val
class _CompareProtocol(Protocol, Generic[V1, V2]): @overload def __call__( self, val: V1, other: V2, name: Optional[str] = None, other_name: Optional[str] = None, allow_none: Literal[False] = False, ) -> V1: ... @overload def __call__( self, val: Optional[V1], other: V2, name: Optional[str] = None, other_name: Optional[str] = None, allow_none: bool = False, ) -> Optional[V1]: ... def _compare_allow_none( func: Callable[Concatenate[V1, V2, P], V1] ) -> _CompareProtocol[V1, V2]: def _allow_none(val, other, name=None, other_name=None, allow_none=False): if val is None and allow_none: return val return func(val, other, name, other_name) return _allow_none
[docs] @_compare_allow_none def is_exactly(val: Any, other: Any, name=None, other_name=None) -> Any: if val is not other: raise ValueError(f"{_nv(name, val)} is not {_nv(other_name, other)}") return val
[docs] @_compare_allow_none def is_equal(val: NumLike, other: NumLike, name=None, other_name=None) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) other_ = torch.as_tensor(is_numlike(other, other_name), device=val_.device) if (val_ != other_).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError( f"{_nv(name, val)} is not {x}equal to {_nv(other_name, other)}" ) return val
[docs] @_compare_allow_none def is_lt( val: NumLike, other: NumLike, name=None, other_name=None, ) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) other_ = torch.as_tensor(is_numlike(other, other_name), device=val_.device) if (val_ >= other_).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError( f"{_nv(name, val)} is not {x}less than {_nv(other_name, other)}" ) return val
[docs] @_compare_allow_none def is_gt( val: NumLike, other: NumLike, name: Optional[str] = None, other_name: Optional[str] = None, ) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) other_ = torch.as_tensor(is_numlike(other, other_name), device=val_.device) if (val_ <= other_).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError( f"{_nv(name, val)} is not {x}greater than {_nv(other_name, other)}" ) return val
[docs] @_compare_allow_none def is_lte( val: NumLike, other: NumLike, name: Optional[str] = None, other_name: Optional[str] = None, ) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) other_ = torch.as_tensor(is_numlike(other, other_name), device=val_.device) if (val_ > other_).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError( f"{_nv(name, val)} is not {x}less than or equal to " f"{_nv(other_name, other)}" ) return val
[docs] @_compare_allow_none def is_gte( val: NumLike, other: NumLike, name: Optional[str] = None, other_name: Optional[str] = None, ) -> NumLike: val_ = torch.as_tensor(is_numlike(val, name)) other_ = torch.as_tensor(is_numlike(other, other_name), device=val_.device) if (val_ < other_).any(): x = "entirely " if val_.numel() > 1 else "" raise ValueError( f"{_nv(name, val)} is not {x}greater than or equal to" f"{_nv(other_name, other)}" ) return val
@overload def is_btw( val: NumLike, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, left_inclusive: bool = False, right_inclusive: bool = False, allow_none: Literal[False] = False, ) -> NumLike: ... @overload def is_btw( val: Optional[NumLike], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, left_inclusive: bool = False, right_inclusive: bool = False, allow_none: bool = False, ) -> Optional[NumLike]: ...
[docs] def is_btw( val, left, right, name=None, left_name=None, right_name=None, left_inclusive=False, right_inclusive=False, allow_none=False, ): if allow_none and val is None: return val val_ = torch.as_tensor(is_numlike(val, name)) try: if left_inclusive: val_ = is_gte(val_, left, name, left_name) else: val_ = is_gt(val_, left, name, left_name) if right_inclusive: val_ = is_lte(val_, right, name, right_name) else: val_ = is_lt(val_, right, name, right_name) except ValueError: x = "entirely " if val_.numel() > 1 else "" y = "incl." if left_inclusive else "excl." z = "incl." if right_inclusive else "excl." raise ValueError( f"{_nv(name, val)} is not {x}within {_nv(left_name, left)} {y} and " f"{_nv(right_name, right)} {z}" ) return val
@overload def is_btw_open( val: NumLike, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: Literal[False] = False, ) -> NumLike: ... @overload def is_btw_open( val: Optional[NumLike], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: bool = False, ) -> Optional[NumLike]: ...
[docs] def is_btw_open( val, left, right, name=None, left_name=None, right_name=None, allow_none: bool = False, ): return is_btw( val, left, right, name, left_name, right_name, False, False, allow_none )
@overload def is_btw_closed( val: NumLike, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: Literal[False] = False, ) -> NumLike: ... @overload def is_btw_closed( val: Optional[NumLike], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: bool = False, ) -> Optional[NumLike]: ...
[docs] def is_btw_closed( val, left, right, name=None, left_name=None, right_name=None, allow_none=False, ): return is_btw(val, left, right, name, left_name, right_name, True, True, allow_none)
[docs] def is_open01(val, name=None, allow_none=False): return is_btw(val, 0, 1, name, None, None, False, False, allow_none)
is_open01 = cast(_IsCheck[NumLike], is_open01)
[docs] def is_closed01(val, name=None, allow_none=False): return is_btw(val, 0, 1, name, None, None, True, True, allow_none)
is_closed01 = cast(_IsCheck[NumLike], is_closed01) def _numlike_special_factory(t: Type[N], *ts: type): _type_check = _type_check_factory(t, *ts) @_is_check_allow_none def _pos_check(val: t, name=None) -> t: val = _type_check(val, name) return is_pos(val, name) @_is_check_allow_none def _neg_check(val: t, name=None) -> t: val = _type_check(val, name) return is_neg(val, name) @_is_check_allow_none def _nonpos_check(val: t, name=None) -> t: val = _type_check(val, name) return is_nonpos(val, name) @_is_check_allow_none def _nonneg_check(val: t, name=None) -> t: val = _type_check(val, name) return is_nonneg(val, name) @_compare_allow_none def _equal_check(val: t, other: NumLike, name=None, other_name=None) -> t: val = _type_check(val, name) return is_equal(val, other, name, other_name) @_compare_allow_none def _lt_check(val: t, other: NumLike, name=None, other_name=None) -> t: val = _type_check(val, name) return is_lt(val, other, name, other_name) @_compare_allow_none def _lte_check( val: t, other: NumLike, name=None, other_name=None, ) -> t: val = _type_check(val, name) return is_lte(val, other, name, other_name) @_compare_allow_none def _gt_check(val: t, other: NumLike, name=None, other_name=None) -> t: val = _type_check(val, name) return is_gt(val, other, name, other_name) @_compare_allow_none def _gte_check(val: t, other: NumLike, name=None, other_name=None) -> t: val = _type_check(val, name) return is_gte(val, other, name, other_name) @overload def _btw_check( val: t, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, left_inclusive: bool = False, right_inclusive: bool = False, allow_none: Literal[False] = False, ) -> t: ... @overload def _btw_check( val: Optional[t], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, left_inclusive: bool = False, right_inclusive: bool = False, allow_none: bool = False, ) -> Optional[t]: ... def _btw_check( val, left, right, name=None, left_name=None, right_name=None, left_inclusive=False, right_inclusive=False, allow_none=False, ): val = _type_check(val, name) return is_btw( val, left, right, name, left_name, right_name, left_inclusive, right_inclusive, allow_none, ) @overload def _btw_open_check( val: t, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: Literal[False] = False, ) -> t: ... @overload def _btw_open_check( val: Optional[t], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: bool = False, ) -> Optional[t]: ... def _btw_open_check( val: t, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none=False, ) -> t: val = _type_check(val, name) return is_btw( val, left, right, name, left_name, right_name, False, False, allow_none ) @overload def _btw_closed_check( val: t, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: Literal[False] = False, ) -> t: ... @overload def _btw_closed_check( val: Optional[t], left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none: bool = False, ) -> Optional[t]: ... def _btw_closed_check( val: t, left: NumLike, right: NumLike, name: Optional[str] = None, left_name: Optional[str] = None, right_name: Optional[str] = None, allow_none=False, ) -> t: val = _type_check(val, name) return is_btw( val, left, right, name, left_name, right_name, True, True, allow_none ) @_is_check_allow_none def _open01_check(val: t, name=None) -> t: val = _type_check(val, name) return is_open01(val, name) @_is_check_allow_none def _closed01_check(val: t, name=None) -> t: val = _type_check(val, name) return is_nonneg(val, name) return ( _pos_check, _neg_check, _nonpos_check, _nonneg_check, _equal_check, _lt_check, _lte_check, _gt_check, _gte_check, _btw_check, _btw_open_check, _btw_closed_check, _open01_check, _closed01_check, ) ( is_posi, is_negi, is_nonposi, is_nonnegi, is_equali, is_lti, is_ltei, is_gti, is_gtei, is_btwi, is_btw_openi, is_btw_closedi, is_open01i, is_closed01i, ) = _numlike_special_factory(int, np.integer) is_nat = is_posi ( is_posf, is_negf, is_nonposf, is_nonnegf, is_equalf, is_ltf, is_ltef, is_gtf, is_gtef, is_btwf, is_btw_openf, is_btw_closedf, is_open01f, is_closed01f, ) = _numlike_special_factory(float, np.floating, int, np.integer) ( is_post, is_negt, is_nonpost, is_nonnegt, is_equalt, is_ltt, is_ltet, is_gtt, is_gtet, is_btwt, is_btw_opent, is_btw_closedt, is_open01t, is_closed01t, ) = _numlike_special_factory(torch.Tensor) @overload def has_ndim( val: torch.Tensor, ndim: int, name: Optional[str] = None, allow_none: Literal[False] = False, ) -> torch.Tensor: ... @overload def has_ndim( val: Optional[torch.Tensor], ndim: int, name: Optional[str] = None, allow_none: bool = False, ) -> Optional[torch.Tensor]: ...
[docs] def has_ndim(val, ndim, name=None, allow_none=False) -> torch.Tensor: if allow_none and val is None: return val if val.ndim != ndim: raise ValueError( f"Expected {_nv(name, val)} to have dimension {ndim}; got {val.ndim}" ) return val
[docs] @_is_check_allow_none def is_nonempty(val, name=None) -> torch.Tensor: if not val.numel(): raise ValueError(f"Expected {_nv(name, val)} to be nonempty") return val
def _cast_factory( cast: Callable[[Any], V1], check: Optional[Callable[[V1, Optional[str]], V1]] = None, cast_name: Optional[str] = None, ): def _cast(val: Any, name: Optional[str] = None) -> V1: try: val = cast(val) if check is not None: val = check(val, name) except: suf = "n" if _cast.__name__.startswith(("a", "e", "i", "o", "u")) else "" raise TypeError( f"Could not cast {_nv(name, val)} as a{suf} {cast.__name__}" ) return val _cast.__name__ = cast.__name__ if cast_name is None else cast_name return _cast as_str = _cast_factory(str) as_int = _cast_factory(int) as_bool = _cast_factory(bool) as_float = _cast_factory(float) as_tensor = _cast_factory(torch.as_tensor, cast_name="tensor") as_posf = _cast_factory(float, is_pos, cast_name="positive float") as_nat = _cast_factory(int, is_pos, cast_name="natural number") as_posi = _cast_factory(int, is_pos, cast_name="positive integer") as_nonnegf = _cast_factory(float, is_nonneg, cast_name="non-negative float") as_nonnegi = _cast_factory(int, is_nonneg, cast_name="non-negative integer") as_negf = _cast_factory(float, is_neg, cast_name="negative float") as_negi = _cast_factory(int, is_neg, cast_name="negative integer") as_nonposf = _cast_factory(float, is_nonpos, cast_name="non-positive float") as_nonposi = _cast_factory(int, is_nonpos, cast_name="non-positive integer") as_open01 = _cast_factory(float, is_open01, cast_name="float within (0, 1)") as_closed01 = _cast_factory(float, is_closed01, cast_name="float within [0, 1]") as_path = _cast_factory(Path) as_path_file = _cast_factory(Path, is_file, cast_name="readable file") as_path_dir = _cast_factory(Path, is_dir, cast_name="readable directory") as_file = _cast_factory(str, is_file, cast_name="readable file") as_dir = _cast_factory(str, is_dir, cast_name="readable directory")