megengine.functional.nn.logsumexp

logsumexp(inp, axis, keepdims=False)

Calculates the logarithm of the sum of exponentials of the input's elements along the given axis.

\[\text{logsumexp}(x) = \log \sum_{j=1}^{n} \exp\left(x_{j}\right)\]

For numerical stability, the implementation follows this transformation:

\[\text{logsumexp}(x) = \log \sum_{j=1}^{n} \exp\left(x_{j}\right) = b + \log \sum_{j=1}^{n} \exp\left(x_{j} - b\right)\]

where

\[b = \max(x_j)\]
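
To illustrate why the shift by b matters, here is a minimal NumPy sketch of the same transformation. It is not MegEngine's implementation: the helper name logsumexp_reference is hypothetical, it works in float64, and it does not handle inputs whose maximum is infinite.

```python
import numpy as np

def logsumexp_reference(x, axis, keepdims=False):
    """NumPy sketch of the max-shift transformation (hypothetical helper, not MegEngine code)."""
    x = np.asarray(x, dtype=np.float64)
    # b = max(x_j) along the reduced axis; keepdims=True so it broadcasts against x.
    b = np.max(x, axis=axis, keepdims=True)
    # Every exponent x_j - b is <= 0, so exp() cannot overflow.
    out = b + np.log(np.sum(np.exp(x - b), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

x = np.array([[1000.0, 1000.0], [-5.0, -4.0]])

# Direct evaluation of log(sum(exp(x))): exp(1000) overflows in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x), axis=1))

# Shifted evaluation: the first row becomes 1000 + log(2).
stable = logsumexp_reference(x, axis=1)
```

In this sketch, the direct form returns inf for the first row, while the shifted form stays finite because the largest exponent is always exp(0) = 1.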

Examples

>>> import numpy as np
>>> import megengine.functional as F
>>> from megengine import Tensor
>>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
>>> y = F.logsumexp(x, axis=1, keepdims=False)
>>> y.numpy().round(decimals=4)
array([-0.5481,  4.4519], dtype=float32)

Return type

Tensor