其他分享
首页 > 其他分享 > [MHA attention mask]之因果causal/因果加backtrace/前后向N帧 mask

[MHA attention mask]之因果causal/因果加backtrace/前后向N帧 mask

作者:互联网

文章目录

在 multi-head attention 中可添加 attention mask,对输入的可见范围进行限定,如下面几种:

1. Attention Mask or Causal Mask

可通过 causal 参数指定生成普通的 attention mask 还是 causal mask:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from tensorflow.keras.layers import Layer, Masking
import tensorflow as tf

class AttentionMask(Layer):
    """
    Keras layer that turns a padded input into an additive attention mask.

    Positions that may be attended to hold 0.0 and all other positions hold
    `mask_value`. When `causal` is set, a lower-triangular restriction is
    combined with the padding mask.
    """

    def __init__(self, causal, mask_value=-1e9):
        """
        Argument/s:
            causal - truthy to add the lower-triangular (causal) restriction.
            mask_value - float written into positions that must not be
                attended to (typically -1e9).

        Raises:
            ValueError - if mask_value is not a float.
        """
        super().__init__()
        if not isinstance(mask_value, float):
            raise ValueError("Mask value must be a float.")
        self.causal = causal
        self.mask_value = mask_value

    def call(self, inp):
        """
        Build the additive attention mask and the pairwise padding mask.

        Argument/s:
            inp - tensor whose first two dimensions are read as batch size
                and sequence length; presumably [batch, seq_len, features]
                (TODO confirm). Padding is detected with
                Masking(mask_value=0.0).

        Returns:
            (att_mask, seq_mask) - att_mask is [batch, seq_len, seq_len]
            holding 0.0 where attention is allowed and `mask_value`
            elsewhere; seq_mask is the pairwise padding mask cast to float32.
        """
        dims = tf.shape(inp)
        n_batch, n_steps = dims[0], dims[1]
        square = [n_batch, n_steps, n_steps]

        # Per-step validity, expanded along each axis and AND-ed so that an
        # entry is True only when both its row and its column are valid.
        valid_steps = Masking(mask_value=0.0).compute_mask(inp)
        col_valid = tf.expand_dims(valid_steps, axis=1)
        row_valid = tf.expand_dims(valid_steps, axis=2)
        seq_mask = self.merge_masks(col_valid, row_valid)

        # Batch dimension of 1 broadcasts against seq_mask when merged.
        if self.causal:
            causal_mask = self.lower_triangular_mask([1, n_steps, n_steps])
        else:
            causal_mask = None
        allowed = self.merge_masks(causal_mask, seq_mask)

        att_mask = tf.where(
            allowed,
            tf.zeros(square),
            tf.fill(square, self.mask_value),
        )
        return att_mask, tf.cast(seq_mask, tf.float32)

    def lower_triangular_mask(self, shape):
        """
        Build a boolean mask that is True on and below the main diagonal of
        the last two dimensions.

        Argument/s:
            shape - shape of the mask to create.

        Returns:
            Boolean tensor of the requested shape (the causal mask).
        """
        ones_for_rows = tf.ones(shape=shape, dtype=tf.int32, name="row")
        ones_for_cols = tf.ones(shape=shape, dtype=tf.int32, name="col")
        # Running sums of ones give 1-based row and column indices.
        row_index = tf.math.cumsum(ones_for_rows, axis=-2)
        col_index = tf.math.cumsum(ones_for_cols, axis=-1)
        return tf.math.less_equal(col_index, row_index)

    def merge_masks(self, x, y):
        """
        Combine two optional boolean masks with a logical AND.

        Argument/s:
            x - mask or None.
            y - mask or None.

        Returns:
            The AND of both masks; if one is None the other is returned
            unchanged (None when both are None).
        """
        if x is not None and y is not None:
            return tf.math.logical_and(x, y)
        return y if x is None else x

2. Causal Mask (with n_backtrace)

即带有 n_backtrace 的因果 mask,继承上面的 AttentionMask:

class AttentionMask_Causal_Backtrace(AttentionMask):
    """
    Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention.

    Combines the padding mask derived from the input with an optional causal
    (lower-triangular) mask and an optional look-back limit of `n_backtrace`
    positions. The result is a float32 tensor of ones (attend) and zeros
    (ignore) shaped [batch, 1, seq_len, seq_len].
    """

    def __init__(self, causal, n_backtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) how many positions before the current one
                may be attended to; None means no look-back limit. 0 limits
                attention to the current position only.
        """
        # The parent constructor already stores `causal`.
        super().__init__(causal)
        self.n_backtrace = n_backtrace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask; its first two dimensions
                are read as batch size and sequence length.

        Returns:
            Attention mask, float32, shaped [batch, 1, seq_len, seq_len].
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(
            tf.expand_dims(flat_seq_mask, axis=1),
            tf.expand_dims(flat_seq_mask, axis=2),
        )

        # A batch dimension of 1 is enough: logical_and broadcasts it against
        # the [batch, seq_len, seq_len] padding mask.
        positional_shape = [1, max_seq_len, max_seq_len]
        causal_mask = (
            self.lower_triangular_mask(positional_shape) if self.causal else None
        )
        # Compare against None rather than relying on truthiness, otherwise
        # n_backtrace=0 would silently disable the look-back limit.
        use_backtrace = self.causal and self.n_backtrace is not None
        bt_mask = self.backtrace_mask(positional_shape) if use_backtrace else None

        logical_mask = self.merge_masks(causal_mask, seq_mask)
        logical_mask = self.merge_masks(logical_mask, bt_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        # Insert the axis of size 1 that broadcasts over attention heads.
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_mask(self, shape):
        """
        Creates a boolean mask over the last 2 dimensions that is True where
        the row index exceeds the column index by at most `n_backtrace`.

        Argument/s:
            shape - shape of mask.

        Returns:
            Backtrace mask (True where row <= col + n_backtrace).
        """
        # Running sums of ones give 1-based row and column indices.
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        return tf.math.less_equal(row_index, col_index + self.n_backtrace)

3. Attention Mask with backtrace and forwardtrace

class AttentionMask_Backtrace_Forwardtrace(AttentionMask):
    """
    Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention.

    Combines the padding mask derived from the input with a positional mask.
    When `causal` is set and both window sizes are given, the positional mask
    is a band allowing `n_backtrace` positions back and `n_forwardtrace`
    positions ahead. When `causal` is set without a complete window, it is
    the plain lower-triangular mask. The result is a float32 tensor of ones
    (attend) and zeros (ignore) shaped [batch, 1, seq_len, seq_len].
    """

    def __init__(self, causal, n_backtrace=None, n_forwardtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag; the window below is only
                applied when this is truthy.
            n_backtrace - (int) number of positions before the current one
                that may be attended to.
            n_forwardtrace - (int) number of positions after the current one
                that may be attended to.
        """
        # The parent constructor already stores `causal`.
        super().__init__(causal)
        self.n_backtrace = n_backtrace
        self.n_forwardtrace = n_forwardtrace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask; its first two dimensions
                are read as batch size and sequence length.

        Returns:
            Attention mask, float32, shaped [batch, 1, seq_len, seq_len].
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(
            tf.expand_dims(flat_seq_mask, axis=1),
            tf.expand_dims(flat_seq_mask, axis=2),
        )

        # A batch dimension of 1 is enough: logical_and broadcasts it against
        # the [batch, seq_len, seq_len] padding mask.
        positional_shape = [1, max_seq_len, max_seq_len]
        # Compare against None so that a window size of 0 is honoured.
        use_window = (
            self.causal
            and self.n_backtrace is not None
            and self.n_forwardtrace is not None
        )
        if use_window:
            # The band replaces the strict lower-triangular mask: AND-ing the
            # two would remove every future position and cancel the forward
            # window. With n_forwardtrace == 0 the band equals the causal
            # mask limited to n_backtrace positions back.
            positional_mask = self.backtrace_forwardtrace_mask(positional_shape)
        elif self.causal:
            positional_mask = self.lower_triangular_mask(positional_shape)
        else:
            positional_mask = None

        logical_mask = self.merge_masks(positional_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        # Insert the axis of size 1 that broadcasts over attention heads.
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_forwardtrace_mask(self, shape):
        """
        Creates a banded boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Band mask, True where
            col - n_forwardtrace <= row <= col + n_backtrace.
        """
        # Running sums of ones give 1-based row and column indices.
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        bt_mask = tf.math.less_equal(row_index, col_index + self.n_backtrace)
        ft_mask = tf.math.greater_equal(row_index + self.n_forwardtrace, col_index)
        return self.merge_masks(bt_mask, ft_mask)

4. Customized Mask

class AttentionMask_Customization(AttentionMask):
    """
    Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention.

    Combines the padding mask derived from the input with a block mask that
    is True only inside the bottom-right `trace` x `trace` corner of the
    [seq_len, seq_len] grid. The result is a float32 tensor of ones (attend)
    and zeros (ignore) shaped [batch, 1, seq_len, seq_len].
    """

    def __init__(self, causal, trace=None):
        """
        Argument/s:
            causal - causal attention mask flag; stored by the parent class
                but not read by this class's call().
            trace - (int) side length of the bottom-right block that stays
                unmasked; None disables the block mask so only the padding
                mask is applied.
        """
        # The parent constructor already stores `causal`.
        super().__init__(causal)
        self.trace = trace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask; its first two dimensions
                are read as batch size and sequence length.

        Returns:
            Attention mask, float32, shaped [batch, 1, seq_len, seq_len].
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(
            tf.expand_dims(flat_seq_mask, axis=1),
            tf.expand_dims(flat_seq_mask, axis=2),
        )

        # With the default trace=None there is no block to build; merge_masks
        # then returns the padding mask unchanged.
        if self.trace is None:
            customized_mask = None
        else:
            customized_mask = self.customized_mask(
                batch_size, max_seq_len, self.trace)

        logical_mask = self.merge_masks(customized_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        # Insert the axis of size 1 that broadcasts over attention heads.
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    @tf.function
    def customized_mask(self, batchsize, max_length, trace):
        """
        Build a boolean mask that is True only in the bottom-right
        `trace` x `trace` block of each [max_length, max_length] grid.

        Argument/s:
            batchsize - size of the leading (batch) dimension.
            max_length - side length of the full mask.
            trace - (int) side length of the unmasked block; values larger
                than max_length are clamped to max_length.

        Returns:
            Boolean tensor shaped [batchsize, max_length, max_length].
        """
        # Clamp so the padding below can never become negative.
        trace = tf.minimum(trace, max_length)
        mask = tf.ones(shape=[batchsize, trace, trace], dtype=tf.int32, name="row")
        # Keep the padding amount as a tensor: inside tf.function max_length
        # is symbolic and cannot be converted with int().
        shape_pad = max_length - trace
        # Zero-pad on the top and left so the block sits in the bottom-right.
        mask = tf.pad(mask, paddings=[[0, 0], [shape_pad, 0], [shape_pad, 0]])
        return tf.cast(mask, dtype=tf.bool)

标签:seq,backtrace,self,mask,shape,tf,causal,因果
来源: https://blog.csdn.net/u010637291/article/details/121438116