megengine.functional.nn.multi_head_attention¶
- multi_head_attention(query, key, value, embed_dim, num_heads, attn_drop, out_drop, io_weight_bias, qproj_size=None, kproj_size=None, vproj_size=None, oproj_size=None, qbias=False, kbias=False, vbias=False, obias=False, bias_k=None, bias_v=None, add_zero_attn=False, key_padding_mask=None, attn_mask=None, need_weights=False, average_attn_weights=False, is_causal=False, maybe_cudnn_style_mask=False, reslink=False, training=True)[source]¶
Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.
\[\text{MultiHeadAttn}\big(q, k, v, W_Q, W_K, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i\]
where \(h_i=W_{V,i}v\,\text{Softmax}\Big( \text{smScaler} \cdot k^TW^T_{K,i}W_{Q,i}q \Big)\), for \(i = 0 \dots nHeads-1\).
See MultiHeadAttn for more details.

Note: This API is experimental and may change in subsequent releases.
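The per-head formula above can be sketched in plain NumPy. This is a simplified, unbatched reference and not the MegEngine implementation; the helper names, the \(Y=X@W\) weight layout, and the choice of \(1/\sqrt{d}\) as the softmax scaler are assumptions of this sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_ref(q, k, v, wq, wk, wv, wo, num_heads):
    # q: (L, E); k, v: (S, E); wq/wk/wv/wo: (E, E); head_dim = E // num_heads
    L, E = q.shape
    S = k.shape[0]
    d = E // num_heads
    # Y = X @ W convention (not X @ W.T as in PyTorch)
    Q = (q @ wq).reshape(L, num_heads, d).transpose(1, 0, 2)  # (H, L, d)
    K = (k @ wk).reshape(S, num_heads, d).transpose(1, 0, 2)  # (H, S, d)
    V = (v @ wv).reshape(S, num_heads, d).transpose(1, 0, 2)  # (H, S, d)
    sm_scaler = 1.0 / np.sqrt(d)                   # assumed scaling factor
    scores = sm_scaler * (Q @ K.transpose(0, 2, 1))          # (H, L, S)
    h = softmax(scores) @ V                                   # (H, L, d)
    # concatenating heads then applying wo equals the sum over W_{O,i} h_i
    return h.transpose(1, 0, 2).reshape(L, E) @ wo            # (L, E)

rng = np.random.default_rng(0)
L, S, E, H = 4, 6, 8, 2
q = rng.normal(size=(L, E))
k = rng.normal(size=(S, E))
v = rng.normal(size=(S, E))
wq, wk, wv, wo = (rng.normal(size=(E, E)) for _ in range(4))
out = multi_head_attention_ref(q, k, v, wq, wk, wv, wo, H)
print(out.shape)  # (4, 8)
```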
- Parameters
  - query (Tensor) – the query tensor; the attention maps a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.
  - key (Tensor) – the key tensor. See “Attention Is All You Need” for more details.
  - value (Tensor) – the value tensor. See “Attention Is All You Need” for more details.
  - embed_dim (int) – total dimension of the model.
  - num_heads (int) – number of parallel attention heads.
  - attn_drop (float) – probability of an element to be zeroed, applied to the attention matrix.
  - out_drop (float) – probability of an element to be zeroed, applied to the final output.
  - io_weight_bias (Optional[Tensor]) – input/output projection weights and biases, packed into a single tensor. The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias. The parameters qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias and obias indicate which of these items are present. Note: \(Y=X@W+B\) is used here instead of \(Y=X@W^T+B\) as in PyTorch.
  - qproj_size (Optional[int]) – projection size of the query weight in io_weight_bias; 0 disables the query projection, in which case there is no query projection weight.
  - kproj_size (Optional[int]) – projection size of the key weight in io_weight_bias; 0 disables the key projection, in which case there is no key projection weight.
  - vproj_size (Optional[int]) – projection size of the value weight in io_weight_bias; 0 disables the value projection, in which case there is no value projection weight.
  - oproj_size (Optional[int]) – projection size of the out weight in io_weight_bias; 0 disables the output projection, in which case there is no output projection weight.
  - qbias (bool) – whether there is a query bias in io_weight_bias; only valid when qproj_size > 0.
  - kbias (bool) – whether there is a key bias in io_weight_bias; only valid when kproj_size > 0.
  - vbias (bool) – whether there is a value bias in io_weight_bias; only valid when vproj_size > 0.
  - obias (bool) – whether there is an output bias in io_weight_bias; only valid when oproj_size > 0.
  - bias_k (Optional[Tensor]) – bias added to the key sequence at the sequence dimension. Distinguished from kbias: bias_k is not a bias of the linear layer; it is added to K at the sequence dimension, where K is the key matrix after projection, and K is then used to compute the attention matrix. Note: should be set to None; configuring this parameter is not supported yet, since only the cuDNN implementation exists for now. This restriction may be loosened once an MHA proxy implementation is added.
  - bias_v (Optional[Tensor]) – bias added to the value sequence at the sequence dimension; analogous to bias_k but applied to the projected value matrix V. Note: should be set to None for the same reason as bias_k.
  - add_zero_attn (bool) – if specified, adds a new batch of zeros to the key and value sequences at the sequence dimension. Default: False. Note: should be set to False; configuring this parameter is not supported yet, since only the cuDNN implementation exists for now. This restriction may be loosened once an MHA proxy implementation is added.
  - key_padding_mask (Optional[Tensor]) – if specified, a mask of shape \((N, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For an unbatched query, the shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be added directly to the corresponding key value.
  - attn_mask (Optional[Tensor]) – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows a different mask for each entry of the batch.
  - need_weights (bool) – whether to return the attention weights, i.e. the output of the softmax. Default: False.
  - average_attn_weights (bool) – if True, the returned attn_weights are averaged across heads; otherwise attn_weights are provided separately per head. This flag only has an effect when need_weights=True. Default: False.
  - is_causal (bool) – if specified, applies a causal mask as the attention mask. Default: False. Warning: is_causal only provides a hint that attn_mask is the causal mask; providing an incorrect hint can result in incorrect execution, as well as forward and backward compatibility issues.
  - maybe_cudnn_style_mask (bool) – if specified, applies a cuDNN-style mask as the attention mask. Default: False. Note: in the cuDNN style, the shape of attn_mask is \((2, L)\) and the shape of key_padding_mask is \((2, N)\). Warning: like is_causal, maybe_cudnn_style_mask only provides a hint that attn_mask and key_padding_mask are cuDNN-style masks; providing an incorrect hint can result in incorrect execution, as well as forward and backward compatibility issues. In addition, if the _merge_masks function returns merge_type=cudnn_style_mask, please ensure the other conditions for running the cuDNN implementation are met, otherwise an error will be reported.
  - reslink (bool) – add the input query to the final output (residual link). Only valid when the shape of the input query matches the shape of the output. Note: should be set to False; configuring this parameter is not supported yet, since only the cuDNN implementation exists for now. This restriction may be loosened once an MHA proxy implementation is added.
  - training (bool) – dropout is applied only when True.
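The packing order of io_weight_bias can be illustrated with a hypothetical unpacking helper. The function name is invented, and the input feature sizes (each of the q/k/v projections reading embed_dim features, and the output projection reading vproj_size) are assumptions of this sketch; only the ordering and the \(Y=X@W+B\) layout come from the documentation above.

```python
import numpy as np

def unpack_io_weight_bias(buf, embed_dim, qproj_size, kproj_size, vproj_size,
                          oproj_size, qbias, kbias, vbias, obias):
    # Slice a flat io_weight_bias buffer into named parts, following the
    # documented order: q/k/v/out weights first, then q/k/v/out biases.
    parts, pos = {}, 0

    def take(shape, name):
        nonlocal pos
        n = int(np.prod(shape))
        parts[name] = buf[pos:pos + n].reshape(shape)
        pos += n

    # weights use (in_features, out_features) so that Y = X @ W + B
    if qproj_size: take((embed_dim, qproj_size), "qweight")
    if kproj_size: take((embed_dim, kproj_size), "kweight")
    if vproj_size: take((embed_dim, vproj_size), "vweight")
    if oproj_size: take((vproj_size, oproj_size), "oweight")
    # a bias exists only when its projection is enabled and its flag is set
    if qproj_size and qbias: take((qproj_size,), "qbias")
    if kproj_size and kbias: take((kproj_size,), "kbias")
    if vproj_size and vbias: take((vproj_size,), "vbias")
    if oproj_size and obias: take((oproj_size,), "obias")
    assert pos == buf.size, "buffer length does not match the declared sizes"
    return parts

E = 8
total = 4 * E * E + 4 * E  # all four projections enabled, all biases on
parts = unpack_io_weight_bias(np.arange(total, dtype=np.float32),
                              E, E, E, E, E, True, True, True, True)
print(sorted(parts))  # ['kbias', 'kweight', 'obias', 'oweight', 'qbias', ...]
```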
- Outputs
  - out[0]=attn_output – attention outputs of shape \((N, L, E)\), where \(L\) is the target sequence length, \(N\) is the batch size, and \(E\) is the embedding dimension embed_dim.
  - out[1]=attn_output_weights – only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads, of shape \((L, S)\) when the input is unbatched or \((N, L, S)\) otherwise, where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. If average_attn_weights=False, returns attention weights per head, of shape \((\text{num\_heads}, L, S)\) when the input is unbatched or \((N * \text{num\_heads}, L, S)\) otherwise.
  - out[2]=mask_reversespace – used to save the dropout mask needed for backward propagation.
  - out[3]=othr_reversespace – used to save intermediate results needed for backward propagation.
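The two attn_output_weights shapes can be checked with a small NumPy sketch. The Dirichlet draw merely fabricates valid softmax rows for the demonstration; it is not how the weights are actually produced.

```python
import numpy as np

rng = np.random.default_rng(1)
N, H, L, S = 2, 4, 3, 5
# per-head attention weights: each row is a probability distribution over
# the S source positions, as the softmax output would be
per_head = rng.dirichlet(np.ones(S), size=(N, H, L))  # (N, H, L, S)

# average_attn_weights=True: average over the head axis -> (N, L, S)
averaged = per_head.mean(axis=1)
# average_attn_weights=False: heads kept separate -> (N * num_heads, L, S)
flattened = per_head.reshape(N * H, L, S)

print(averaged.shape)                      # (2, 3, 5)
print(flattened.shape)                     # (8, 3, 5)
print(np.allclose(averaged.sum(-1), 1.0))  # True: rows stay normalized
```

Averaging distributions keeps each row normalized, so the averaged weights are still valid attention distributions.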