# -*- coding: utf-8 -*-
import abc
import collections
from functools import lru_cache
from typing import Union
import numpy as np
from .. import _config
from .._imperative_rt.core2 import (
Tensor,
apply,
astype_cpp,
batched_matmul_cpp,
broadcast_cpp,
expand_dims_cpp,
getitem_cpp,
matmul_cpp,
reshape_cpp,
setitem_cpp,
squeeze_cpp,
transpose_cpp,
)
from ..ops import builtin
from . import amp
from .utils import (
_normalize_axis,
astensor1d,
cast_tensors,
convert_inputs,
make_shape_tuple,
subgraph,
)
# Short alias for the element-wise mode enumeration referenced by the
# operator definitions in this module (e.g. ``_ElwMod.ADD``).
_ElwMod = builtin.Elemwise.Mode
def _elemwise_multi_type(*args, mode, **kwargs):
    """Apply a ``builtin.ElemwiseMultiType`` op to ``args``.

    Args:
        *args: operands forwarded to ``apply``.
        mode: mode passed to ``builtin.ElemwiseMultiType``.
        **kwargs: extra keyword options for the op constructor.

    Returns:
        The single output produced by the op.
    """
    outputs = apply(builtin.ElemwiseMultiType(mode=mode, **kwargs), *args)
    # The op is expected to yield exactly one output; unpacking enforces it.
    (only_output,) = outputs
    return only_output
def _elwise_apply(args, mode):
    """Apply a ``builtin.Elemwise`` op of the given ``mode`` to ``args``.

    Args:
        args: iterable of operands, unpacked into ``apply``.
        mode: mode passed positionally to ``builtin.Elemwise``.

    Returns:
        The single output produced by the op.
    """
    outputs = apply(builtin.Elemwise(mode), *args)
    # Exactly one output is expected; unpacking enforces it.
    (only_output,) = outputs
    return only_output
def _elwise(*args, mode):
    """Variadic front-end for :func:`_elwise_apply`.

    Collects the positional operands into a tuple and forwards them together
    with ``mode``.
    """
    operands = args
    return _elwise_apply(operands, mode)
class _Hashable:
def __init__(self, value) -> None:
self.value = value
def __hash__(self) -> int:
return hash(str(self.value))
def __eq__(self, o: object) -> bool:
if not isinstance(o, _Hashable):
return False
return self.value == o.value
def _matmul(inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default"):
    """Multiply ``inp1`` by ``inp2``, choosing the kernel from operand ranks.

    Dispatch rules:

    * both operands 1-D: ``builtin.Dot``;
    * both operands at most 2-D, or the right operand at most 2-D with
      ``transpose_a`` false: ``matmul_cpp`` (MatrixMul);
    * everything else: ``batched_matmul_cpp`` (BatchedMatrixMul).

    A 1-D left operand gets an axis inserted at position 0 and a 1-D right
    operand gets one appended at the end before the kernel call; the
    matching axis is squeezed from the result afterwards.

    Args:
        inp1: left operand; must have ``ndim > 0``.
        inp2: right operand; must have ``ndim > 0``.
        transpose_a: forwarded to the kernel.
        transpose_b: forwarded to the kernel.
        compute_mode: resolved through ``_config._get_actual_op_param``
            before being forwarded to the kernel.

    Returns:
        The product returned by the selected kernel.
    """
    ndim_a, ndim_b = inp1.ndim, inp2.ndim
    assert ndim_a > 0 and ndim_b > 0
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

    # vector . vector -> Dot
    if ndim_a == 1 and ndim_b == 1:
        (dot,) = apply(builtin.Dot(), inp1, inp2)
        return dot

    # MatrixMul handles:
    #   2x1, 1x2, 2x2
    #   nx1 / nx2 with transpose_a=False, n>=3
    # BatchedMatrixMul handles the rest:
    #   nx1 / nx2 with transpose_a=True, n>=3
    #   nxm with n>=3, m>=3
    #   1xm / 2xm with m>=3
    use_matrix_mul = max(ndim_a, ndim_b) <= 2 or (ndim_b <= 2 and not transpose_a)
    kernel = matmul_cpp if use_matrix_mul else batched_matmul_cpp

    lhs = inp1 if ndim_a > 1 else expand_dims_cpp(inp1, 0)
    rhs = inp2 if ndim_b > 1 else expand_dims_cpp(inp2, -1)
    product = kernel(
        lhs,
        rhs,
        max(ndim_a, 2),
        max(ndim_b, 2),
        transpose_a,
        transpose_b,
        compute_mode,
        _config._benchmark_kernel,
        _config._deterministic_kernel,
    )

    # Drop the axis that was inserted for a 1-D operand.
    if ndim_a == 1:
        return squeeze_cpp(product, -2)
    if ndim_b == 1:
        return squeeze_cpp(product, -1)
    return product
def _unary_elwise(mode):
def f(self):
return _elwise(self, mode=mode)
return f
def _binary_elwise(mode, rev=False):
if not rev:
def f(self, value):
return _elwise(self, value, mode=mode)
else:
def f(self, value):
return _elwise(value, self, mode=mode)
return f
def _logical_unary_elwise(mode, rev=False):
def f(self):
if self.dtype != np.bool_:
raise TypeError("{} requires a bool tensor".format(mode))
return _elwise(self, mode=mode)
return f
def _logical_binary_elwise(mode, rev=False):
if not rev:
def f(self, value):
if self.dtype != np.bool_ or value.dtype != np.bool_:
raise TypeError("{} requires 2 bool tensors".format(mode))
return _elwise(self, value, mode=mode)
else:
def f(self, value):
if self.dtype != np.bool_ or value.dtype != np.bool_:
raise TypeError("{} requires 2 bool tensors".format(mode))
return _elwise(value, self, mode=mode)
return f
def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
(result,) = apply(builtin.Reduce(mode=mode), data)
elif isinstance(axis, collections.abc.Iterable):
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai, keepdim=keepdims)
(data,) = apply(op, data)
result = data
else:
# builtin.Reduce already accept negtive axis
op = builtin.Reduce(mode=mode, axis=axis, keepdim=keepdims)
(result,) = apply(op, data)
return result
return f
def _inplace(f):
def g(self, value):
result = f(self, value)
if result is NotImplemented:
raise NotImplementedError
self._reset(result)
return self
return g
def _todo(*_):
raise NotImplementedError
def _expand_args(args):
if len(args) == 1:
if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
args = args[0]
return args
class ArrayMethodMixin(abc.ABC):
    r"""Mixin supplying Python operators and NumPy-style array methods.

    Subclasses provide the abstract hooks ``_reset``, ``dtype``, ``shape``,
    ``_tuple_shape`` and ``numpy``. The operators and methods defined here
    are implemented on top of those hooks and of the element-wise, reduction
    and C++ helpers used by this module.
    """

    # enable tensor to be converted to numpy array
    __array_priority__ = 1001

    def __array__(self, dtype=None):
        r"""Return the contents as a :class:`numpy.ndarray`.

        Args:
            dtype: optional target dtype. When given, the array returned by
                :meth:`numpy` is converted with ``astype(dtype)``.
        """
        # Identity test on purpose: ``dtype == None`` is dispatched to the
        # dtype object's own ``__eq__``, so a requested dtype could be
        # mistaken for "no dtype requested".
        if dtype is None:
            return self.numpy()
        return self.numpy().astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        r"""Wrap a NumPy result into a new instance of ``type(self)``.

        ``context`` and ``return_scalar`` are accepted only because NumPy may
        pass them positionally; they are not used.
        """
        Wrapper = type(self)
        return Wrapper(array, dtype=array.dtype, device=self.device)

    @abc.abstractmethod
    def _reset(self, other):
        pass

    @property
    @abc.abstractmethod
    def dtype(self) -> np.dtype:
        pass

    @property
    @abc.abstractmethod
    def shape(self) -> Union[tuple, Tensor]:
        pass

    @property
    @abc.abstractmethod
    def _tuple_shape(self) -> tuple:
        pass

    @abc.abstractmethod
    def numpy(self) -> np.ndarray:
        pass

    __hash__ = None  # because __eq__ deviates from the Python convention

    # Rich comparisons are element-wise and produce bool results. ``>`` and
    # ``>=`` reuse the "lt" / "leq" modes with swapped operands.
    __lt__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="lt", dtype="bool"
    )
    __le__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="leq", dtype="bool"
    )
    __gt__ = lambda self, value: _elemwise_multi_type(
        value, self, mode="lt", dtype="bool"
    )
    __ge__ = lambda self, value: _elemwise_multi_type(
        value, self, mode="leq", dtype="bool"
    )
    __eq__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="eq", dtype="bool"
    )
    __ne__ = lambda self, value: _elemwise_multi_type(
        self, value, mode="neq", dtype="bool"
    )

    # Unary operators.
    __neg__ = _unary_elwise(_ElwMod.NEGATE)
    __pos__ = lambda self: self
    __abs__ = _unary_elwise(_ElwMod.ABS)
    __invert__ = _logical_unary_elwise(_ElwMod.NOT)
    __round__ = _unary_elwise(_ElwMod.ROUND)
    __trunc__ = _todo
    __floor__ = _unary_elwise(_ElwMod.FLOOR)
    __ceil__ = _unary_elwise(_ElwMod.CEIL)

    # Binary operators.
    __add__ = _binary_elwise(_ElwMod.ADD)
    __sub__ = _binary_elwise(_ElwMod.SUB)
    __mul__ = _binary_elwise(_ElwMod.MUL)
    __matmul__ = lambda self, other: _matmul(self, other)
    __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV)
    __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV)
    __mod__ = _binary_elwise(_ElwMod.MOD)
    # __divmod__ is not provided
    __pow__ = _binary_elwise(_ElwMod.POW)
    __lshift__ = _binary_elwise(_ElwMod.SHL)
    __rshift__ = _binary_elwise(_ElwMod.SHR)
    __and__ = _logical_binary_elwise(_ElwMod.AND)
    __or__ = _logical_binary_elwise(_ElwMod.OR)
    __xor__ = _logical_binary_elwise(_ElwMod.XOR)

    # Reflected binary operators (operands swapped via ``rev``).
    __radd__ = _binary_elwise(_ElwMod.ADD, rev=1)
    __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1)
    __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1)
    __rmatmul__ = lambda self, other: _matmul(other, self)
    __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1)
    __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1)
    __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1)
    # __rdivmod__ is not provided
    __rpow__ = _binary_elwise(_ElwMod.POW, rev=1)
    __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1)
    __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1)
    __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1)
    __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1)
    __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1)

    # In-place operators: compute the out-of-place result, then ``_reset``.
    __iadd__ = _inplace(__add__)
    __isub__ = _inplace(__sub__)
    __imul__ = _inplace(__mul__)
    __imatmul__ = _inplace(__matmul__)
    __itruediv__ = _inplace(__truediv__)
    __ifloordiv__ = _inplace(__floordiv__)
    __imod__ = _inplace(__mod__)
    __ipow__ = _inplace(__pow__)
    __ilshift__ = _inplace(__lshift__)
    __irshift__ = _inplace(__rshift__)
    __iand__ = _inplace(__and__)
    __ior__ = _inplace(__or__)
    __ixor__ = _inplace(__xor__)

    # Scalar conversions, all based on :meth:`item`.
    __index__ = lambda self: self.item().__index__()
    __bool__ = lambda self: bool(self.item())
    __int__ = lambda self: int(self.item())
    __float__ = lambda self: float(self.item())
    __complex__ = lambda self: complex(self.item())

    def __len__(self):
        r"""Return the extent of the first dimension.

        Raises:
            TypeError: if ``_tuple_shape`` is empty or otherwise falsy.
        """
        shape = self._tuple_shape
        if shape:
            return int(shape[0])
        raise TypeError("ndim is 0")

    def __iter__(self):
        r"""Yield ``self[i]`` for every index along the first dimension."""
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return getitem_cpp(self, index)

    def __setitem__(self, index, value):
        # ``self[...] = value`` replaces the whole content, so ``value`` is
        # handed to ``_reset`` as is; any other index goes through
        # ``setitem_cpp`` first.
        if index is not Ellipsis:
            value = setitem_cpp(self, index, value)
        self._reset(value)

    __contains__ = _todo

    @property
    def ndim(self):
        r"""Returns the number of dimensions of self :class:`~.Tensor`."""
        shape = self._tuple_shape
        if shape is None:
            raise ValueError("unkown ndim")
        return len(shape)

    @property
    def size(self):
        r"""Returns the number of elements of the self :class:`~.Tensor`.

        When ``shape`` is a plain :class:`tuple` the result is a Python
        :class:`int`; otherwise it is whatever ``shape.prod()`` returns.
        """
        shape = self.shape
        if shape.__class__ is tuple:
            # Force an integer accumulator: ``np.prod(())`` would otherwise
            # yield the float ``1.0`` for a 0-dim shape.
            return np.prod(shape, dtype=np.int64).item()
        return shape.prod()

    @property
    def T(self):
        r"""alias of :attr:`~.Tensor.transpose`."""
        return self.transpose()

    def item(self, *args):
        r"""Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
        This only works for tensors with one element. For other cases, see :meth:`~.tolist`.

        Args:
            *args: optional index; when given, ``self[args].item()`` is
                returned instead.
        """
        if not args:
            size = self.size
            # The one-element check is only possible when the size is a
            # plain int.
            if isinstance(size, int):
                assert size == 1
            return self.numpy().item()
        return self[args].item()

    def tolist(self):
        r"""Returns the tensor as a (nested) list.
        For scalars, a standard Python number is returned, just like with :meth:`~.item`.
        Tensors are automatically moved to the CPU first if necessary.
        This operation is not differentiable.
        """
        return self.numpy().tolist()

    def astype(self, dtype):
        r"""Returns a :class:`Tensor` with the same data and number of elements
        with the specified :attr:`~.Tensor.dtype`.
        """
        return astype_cpp(self, dtype)

    def reshape(self, *args):
        r"""See :func:`~.reshape`."""
        return reshape_cpp(self, args)

    # FIXME: remove this method
    def _broadcast(self, *args):
        return broadcast_cpp(self, args)

    def transpose(self, *args):
        r"""See :func:`~.transpose`."""
        return transpose_cpp(self, args)

    def flatten(self, start_axis: int = 0, end_axis: int = -1):
        r"""See :func:`~.flatten`."""
        inp_shape = self.shape
        if start_axis < 0:
            start_axis += len(inp_shape)
        # Keep the leading axes, then merge the flattened range into a
        # single ``-1`` entry.
        target_shape = tuple(inp_shape[i] for i in range(start_axis)) + (-1,)
        if end_axis != -1:
            # Axes after ``end_axis`` are kept unchanged.
            target_shape += (*inp_shape[end_axis + 1 :],)
        return reshape_cpp(self, target_shape)

    def sum(self, axis=None, keepdims: bool = False):
        r"""See :func:`~.sum`."""
        return _reduce("sum")(self, axis, keepdims)

    def prod(self, axis=None, keepdims: bool = False):
        r"""See :func:`~.prod`."""
        return _reduce("product")(self, axis, keepdims)

    def min(self, axis=None, keepdims: bool = False):
        r"""See :func:`~.min`."""
        return _reduce("min")(self, axis, keepdims)

    def max(self, axis=None, keepdims: bool = False):
        r"""See :func:`~.max`."""
        return _reduce("max")(self, axis, keepdims)

    def mean(self, axis=None, keepdims: bool = False):
        r"""See :func:`~.mean`."""
        return _reduce("mean")(self, axis, keepdims)