# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import numpy as np
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import convert_inputs
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
# Public API of this module; order matches the original declaration.
__all__ = [
    "abs", "add", "acos", "asin", "atan", "atan2", "asinh", "acosh", "atanh",
    "ceil", "clip", "cos", "cosh", "div", "equal", "exp", "expm1", "floor",
    "floor_div", "greater", "greater_equal", "left_shift", "less",
    "less_equal", "log", "log1p", "logical_and", "logical_not", "logical_or",
    "logical_xor", "maximum", "minimum", "mod", "mul", "neg", "not_equal",
    "pow", "right_shift", "round", "sin", "sinh", "sqrt", "square", "sub",
    "tan", "tanh",
]
def _elemwise_multi_type(*args, mode, **kwargs):
    """Run a single ``ElemwiseMultiType`` op over ``args``.

    Args:
        *args: operands; they are passed through ``convert_inputs`` before
            the op is applied.
        mode: mode forwarded to ``builtin.ElemwiseMultiType``.
        **kwargs: extra keyword arguments forwarded to the op constructor.

    Returns:
        the single output produced by ``apply``. The tuple unpacking raises
        ``ValueError`` if the op yields a different number of outputs.
    """
    multi_type_op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    converted = convert_inputs(*args)
    outputs = apply(multi_type_op, *converted)
    (result,) = outputs
    return result
# math operations
def add(x, y):
    r"""Element-wise addition, ``x + y``.

    Args:
        x: first operand.
        y: second operand.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.add(x, y)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0.  2.  4.]
             [ 6.  8. 10.]]
    """
    elw_mode = Elemwise.Mode.ADD
    return _elwise(x, y, mode=elw_mode)
def sub(x, y):
    r"""Element-wise subtraction, ``x - y``.

    Args:
        x: value to subtract from.
        y: value to subtract.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
            y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.sub(x, y)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 1. 1.]
             [1. 1. 1.]]
    """
    elw_mode = Elemwise.Mode.SUB
    return _elwise(x, y, mode=elw_mode)
def mul(x, y):
    r"""Element-wise multiplication, ``x * y``.

    Args:
        x: first factor.
        y: second factor.
    """
    elw_mode = Elemwise.Mode.MUL
    return _elwise(x, y, mode=elw_mode)
def div(x, y):
    r"""Element-wise true division, ``x / y``.

    Args:
        x: dividend.
        y: divisor.
    """
    elw_mode = Elemwise.Mode.TRUE_DIV
    return _elwise(x, y, mode=elw_mode)
def floor_div(x, y):
    r"""Element-wise floor division, ``floor(x / y)``.

    Args:
        x: dividend.
        y: divisor.
    """
    elw_mode = Elemwise.Mode.FLOOR_DIV
    return _elwise(x, y, mode=elw_mode)
def neg(x):
    r"""Element-wise negation, ``-x``.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.NEGATE
    return _elwise(x, mode=elw_mode)
def pow(x, y):
    r"""Element-wise power, ``x ** y``.

    Args:
        x: base.
        y: exponent.
    """
    elw_mode = Elemwise.Mode.POW
    return _elwise(x, y, mode=elw_mode)
def mod(x, y):
    r"""Element-wise remainder of the division ``x / y``.

    Args:
        x: dividend.
        y: divisor.
    """
    elw_mode = Elemwise.Mode.MOD
    return _elwise(x, y, mode=elw_mode)
def abs(x):
    r"""Element-wise absolute value, ``|x|``.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.ABS
    return _elwise(x, mode=elw_mode)
def exp(x):
    r"""Element-wise natural exponential, ``e ** x``.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.EXP
    return _elwise(x, mode=elw_mode)
def expm1(x):
    r"""Element-wise ``exp(x) - 1``, evaluated by a dedicated EXPM1 kernel.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.EXPM1
    return _elwise(x, mode=elw_mode)
def log(x):
    r"""Element-wise natural logarithm (base e).

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.LOG
    return _elwise(x, mode=elw_mode)
def log1p(x):
    r"""Element-wise ``log(1 + x)`` (base e), evaluated by a dedicated LOG1P kernel.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.LOG1P
    return _elwise(x, mode=elw_mode)
def sqrt(x: Tensor) -> Tensor:
    r"""Element-wise square root.

    Computed through the power operator as ``x ** 0.5``.

    Args:
        x: input tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.sqrt(x)
            print(out.numpy().round(decimals=4))

        Outputs:

        .. testoutput::

            [[0.     1.     1.4142]
             [1.7321 2.     2.2361]]
    """
    half = 0.5
    return x ** half
def square(x: Tensor) -> Tensor:
    r"""Element-wise square.

    Computed through the power operator as ``x ** 2``.

    Args:
        x: input tensor.

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.functional as F

            data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.square(data)
            print(out.numpy().round(decimals=4))

        Outputs:

        .. testoutput::

            [[ 0.  1.  4.]
             [ 9. 16. 25.]]
    """
    exponent = 2
    return x ** exponent
def round(x):
    r"""Element-wise rounding to integral values.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.ROUND
    return _elwise(x, mode=elw_mode)
def ceil(x):
    r"""Element-wise ceiling (round toward positive infinity).

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.CEIL
    return _elwise(x, mode=elw_mode)
def floor(x):
    r"""Element-wise floor (round toward negative infinity).

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.FLOOR
    return _elwise(x, mode=elw_mode)
def maximum(x, y):
    r"""Element-wise maximum of ``x`` and ``y``.

    Args:
        x: first operand.
        y: second operand.
    """
    elw_mode = Elemwise.Mode.MAX
    return _elwise(x, y, mode=elw_mode)
def minimum(x, y):
    r"""Element-wise minimum of ``x`` and ``y``.

    Args:
        x: first operand.
        y: second operand.
    """
    elw_mode = Elemwise.Mode.MIN
    return _elwise(x, y, mode=elw_mode)
# trigonometric functions
def cos(x):
    r"""Element-wise cosine.

    Args:
        x: input value (radians).

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.cos(x)
            print(out.numpy().round(decimals=4))

        Outputs:

        .. testoutput::

            [[ 1.      0.5403 -0.4161]
             [-0.99   -0.6536  0.2837]]
    """
    elw_mode = Elemwise.Mode.COS
    return _elwise(x, mode=elw_mode)
def sin(x):
    r"""Element-wise sine.

    Args:
        x: input value (radians).
    """
    elw_mode = Elemwise.Mode.SIN
    return _elwise(x, mode=elw_mode)
def tan(x):
    r"""Element-wise tangent, computed as ``sin(x) / cos(x)``.

    Args:
        x: input value (radians).
    """
    numerator = sin(x)
    denominator = cos(x)
    return numerator / denominator
def acos(x):
    r"""Element-wise inverse cosine.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.ACOS
    return _elwise(x, mode=elw_mode)
def asin(x):
    r"""Element-wise inverse sine.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.ASIN
    return _elwise(x, mode=elw_mode)
def atan(x):
    r"""Element-wise inverse tangent.

    Implemented with the two-argument ATAN2 kernel as ``atan2(x, 1)``.

    Args:
        x: input value.
    """
    unit_denominator = 1
    return _elwise(x, unit_denominator, mode=Elemwise.Mode.ATAN2)
def atan2(y, x):
    r"""Element-wise two-argument arctangent of ``y`` and ``x``.

    Args:
        y: first operand (passed first to the ATAN2 kernel).
        x: second operand.
    """
    elw_mode = Elemwise.Mode.ATAN2
    return _elwise(y, x, mode=elw_mode)
def cosh(x):
    r"""Element-wise hyperbolic cosine, ``(exp(x) + exp(-x)) / 2``.

    Args:
        x: input value.
    """
    e_pos = exp(x)
    e_neg = exp(-x)
    return 0.5 * (e_pos + e_neg)
def sinh(x):
    r"""Element-wise hyperbolic sine.

    With ``u = expm1(x)`` this uses the identity
    ``sinh(x) = u * (u + 2) / (2 * (u + 1))``. That avoids subtracting
    ``exp(-x)`` from ``exp(x)``, which loses precision for small ``x``.

    Args:
        x: input value.
    """
    em1 = expm1(x)
    half_em1 = 0.5 * em1
    scaled = half_em1 / (em1 + 1)
    return scaled * (em1 + 2)
def tanh(x):
    r"""Element-wise hyperbolic tangent.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.TANH
    return _elwise(x, mode=elw_mode)
def asinh(x):
    r"""Element-wise inverse hyperbolic sine, ``log(x + sqrt(x**2 + 1))``.

    Args:
        x: input value.
    """
    # NOTE(review): for large negative x, ``x`` and ``root`` nearly cancel.
    # Low-precision (e.g. float32) inputs can therefore lose accuracy or
    # yield -inf. An odd-symmetric form, sign(x) * asinh(|x|), would avoid
    # this, but it needs a sign/select op not visible in this module.
    root = (x ** 2 + 1) ** 0.5
    return log(x + root)
def acosh(x):
    r"""Element-wise inverse hyperbolic cosine, ``log(x + sqrt(x**2 - 1))``.

    Args:
        x: input value. Inputs with ``|x| < 1`` make the square-root operand
            negative.
    """
    root = (x ** 2 - 1) ** 0.5
    return log(x + root)
def atanh(x):
    r"""Element-wise inverse hyperbolic tangent.

    Uses ``log1p(2x / (1 - x)) / 2``, which equals
    ``log((1 + x) / (1 - x)) / 2``.

    Args:
        x: input value.
    """
    ratio = 2 * x / (1 - x)
    return log1p(ratio) / 2
# bit-twiddling functions
def left_shift(x, y):
    r"""Element-wise bitwise left shift, ``x << y``.

    Args:
        x: value to shift.
        y: shift amount.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
            out = F.left_shift(x, 2)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0  4  8]
             [12 16 20]]
    """
    elw_mode = Elemwise.Mode.SHL
    return _elwise(x, y, mode=elw_mode)
def right_shift(x, y):
    r"""Element-wise bitwise right shift, ``x >> y``.

    Args:
        x: value to shift.
        y: shift amount.
    """
    elw_mode = Elemwise.Mode.SHR
    return _elwise(x, y, mode=elw_mode)
# logical functions
def logical_and(x, y):
    r"""Element-wise logical AND, ``x && y``.

    Args:
        x: first operand.
        y: second operand.
    """
    elw_mode = Elemwise.Mode.AND
    return _elwise(x, y, mode=elw_mode)
def logical_not(x):
    r"""Element-wise logical NOT, ``~x``.

    Args:
        x: input value.
    """
    elw_mode = Elemwise.Mode.NOT
    return _elwise(x, mode=elw_mode)
def logical_or(x, y):
    r"""Element-wise logical OR, ``x || y``.

    Args:
        x: first operand.
        y: second operand.
    """
    elw_mode = Elemwise.Mode.OR
    return _elwise(x, y, mode=elw_mode)
def logical_xor(x, y):
    r"""Element-wise logical XOR, ``x ^ y``.

    Args:
        x: first operand.
        y: second operand.
    """
    elw_mode = Elemwise.Mode.XOR
    return _elwise(x, y, mode=elw_mode)
# comparison functions
def equal(x, y):
    r"""Element-wise equality test, ``x == y``.

    Args:
        x: first operand.
        y: second operand.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
            out = F.equal(x, y)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 1. 1.]
             [1. 1. 1.]]
    """
    elw_mode = Elemwise.Mode.EQ
    return _elwise(x, y, mode=elw_mode)
def not_equal(x, y):
    r"""Element-wise inequality test, ``x != y``.

    Delegates to the operands' ``!=`` operator rather than calling
    ``_elwise`` directly.

    Args:
        x: first operand.
        y: second operand.
    """
    differs = x != y
    return differs
def less(x, y):
    r"""Element-wise comparison ``x < y``.

    Args:
        x: left operand.
        y: right operand.
    """
    elw_mode = Elemwise.Mode.LT
    return _elwise(x, y, mode=elw_mode)
def less_equal(x, y):
    r"""Element-wise comparison ``x <= y``.

    Args:
        x: left operand.
        y: right operand.
    """
    elw_mode = Elemwise.Mode.LEQ
    return _elwise(x, y, mode=elw_mode)
def greater(x, y):
    r"""Element-wise comparison ``x > y``.

    Evaluated as ``y < x``, so it reuses the LT kernel with swapped operands.

    Args:
        x: left operand.
        y: right operand.
    """
    lhs, rhs = y, x
    return _elwise(lhs, rhs, mode=Elemwise.Mode.LT)
def greater_equal(x, y):
    r"""Element-wise comparison ``x >= y``.

    Evaluated as ``y <= x``, so it reuses the LEQ kernel with swapped operands.

    Args:
        x: left operand.
        y: right operand.
    """
    lhs, rhs = y, x
    return _elwise(lhs, rhs, mode=Elemwise.Mode.LEQ)
# other functions
def clip(x: Tensor, lower=None, upper=None) -> Tensor:
    r"""Clamps all elements in input tensor into the range ``[ lower, upper ]`` and returns
    a resulting tensor:

    .. math::

        y_i = \begin{cases}
            \text{lower} & \text{if } x_i < \text{lower} \\
            x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
            \text{upper} & \text{if } x_i > \text{upper}
        \end{cases}

    Args:
        x: input tensor.
        lower: lower-bound of the range to be clamped to; ``None`` means no
            lower bound.
        upper: upper-bound of the range to be clamped to; ``None`` means no
            upper bound.

    Returns:
        output clamped tensor.

    Raises:
        AssertionError: if both ``lower`` and ``upper`` are ``None``.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            a = tensor(np.arange(5).astype(np.int32))
            print(F.clip(a, 2, 4).numpy())
            print(F.clip(a, lower=3).numpy())
            print(F.clip(a, upper=3).numpy())

        Outputs:

        .. testoutput::

            [2 2 2 3 4]
            [3 3 3 3 4]
            [0 1 2 3 3]
    """
    # Raise explicitly instead of using ``assert``: assert statements are
    # stripped under ``python -O``. In that case ``minimum(x, None)`` would
    # fail with an unrelated error. AssertionError and its message are kept
    # so callers that already catch it are unaffected.
    if lower is None and upper is None:
        raise AssertionError("At least one of 'lower' or 'upper' must not be None")
    if lower is None:
        return minimum(x, upper)
    if upper is None:
        return maximum(x, lower)
    # Apply the lower bound first, then the upper bound. If lower > upper,
    # every element therefore ends up equal to ``upper``.
    return minimum(maximum(x, lower), upper)
# Deprecated aliases kept for backward compatibility. Each name is built by
# ``deprecated_func`` with the version string "1.3", the module path
# "megengine.functional.nn" and the same function name. Presumably each alias
# forwards to that function and emits a deprecation warning; see
# ``deprecated_func``.
# NOTE(review): the meaning of the trailing ``True`` argument is defined by
# ``deprecated_func`` and is not visible here. Confirm before documenting it.
sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True)
hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True)
relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True)
relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True)
hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True)