summaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/numpy/_typing/_nbit_base.py
blob: 28d3e63c176930cf6b84f91b0bc119b431b6e126 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""A module with the precisions of generic `~numpy.number` types."""
from typing import final

from numpy._utils import set_module


@final  # Disallow the creation of arbitrary `NBitBase` subclasses
@set_module("numpy.typing")
class NBitBase:
    """
    A type representing `numpy.number` precision during static type checking.

    Used exclusively for the purpose of static type checking, `NBitBase`
    represents the base of a hierarchical set of subclasses.
    Each subsequent subclass is herein used for representing a lower level
    of precision, *e.g.* ``64Bit > 32Bit > 16Bit``.

    .. versionadded:: 1.20

    .. deprecated:: 2.3
        Use ``@typing.overload`` or a ``TypeVar`` with a scalar-type as upper
        bound, instead.

    Examples
    --------
    Below is a typical usage example: `NBitBase` is herein used for annotating
    a function that takes a float and integer of arbitrary precision
    as arguments and returns a new float of whichever precision is largest
    (*e.g.* ``np.float16 + np.int64 -> np.float64``).

    .. code-block:: python

        >>> from typing import TypeVar, TYPE_CHECKING
        >>> import numpy as np
        >>> import numpy.typing as npt

        >>> S = TypeVar("S", bound=npt.NBitBase)
        >>> T = TypeVar("T", bound=npt.NBitBase)

        >>> def add(a: np.floating[S], b: np.integer[T]) -> np.floating[S | T]:
        ...     return a + b

        >>> a = np.float16()
        >>> b = np.int64()
        >>> out = add(a, b)

        >>> if TYPE_CHECKING:
        ...     reveal_locals()
        ...     # note: Revealed local types are:
        ...     # note:     a: numpy.floating[numpy.typing._16Bit*]
        ...     # note:     b: numpy.signedinteger[numpy.typing._64Bit*]
        ...     # note:     out: numpy.floating[numpy.typing._64Bit*]

    """
    # Deprecated in NumPy 2.3, 2025-05-01

    def __init_subclass__(cls) -> None:
        allowed_names = {
            "NBitBase", "_128Bit", "_96Bit", "_64Bit", "_32Bit", "_16Bit", "_8Bit"
        }
        if cls.__name__ not in allowed_names:
            raise TypeError('cannot inherit from final class "NBitBase"')
        super().__init_subclass__()

@final
@set_module("numpy._typing")
# Silence errors about subclassing a `@final`-decorated class
class _128Bit(NBitBase):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass

@final
@set_module("numpy._typing")
class _96Bit(_128Bit):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass

@final
@set_module("numpy._typing")
class _64Bit(_96Bit):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass

@final
@set_module("numpy._typing")
class _32Bit(_64Bit):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass

@final
@set_module("numpy._typing")
class _16Bit(_32Bit):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass

@final
@set_module("numpy._typing")
class _8Bit(_16Bit):  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
    pass