Code Review for CGBERT (Context-Guided BERT)

Reference: https://github.com/frankaging/Quasi-Attention-ABSA/blob/main/code/model/CGBERT.py


1. Architecture

( Figure 2 )

ContextBertModel

  • (1) BERTEmbedding

  • (2) Context BERTEncoder x N

    • 2-1) Context BERT Attention
      • a) Context BERT Self-Attention
      • b) Context BERT Self-Output
    • 2-2) BERT Intermediate
    • 2-3) BERT Output


2. Class & Function introduction

1) Main Modules

Model : ContextBertModel

  • BERT Embedding : BERTEmbedding

  • Encoder : BERTEncoder

    • Bert layer : BERTLayer
      • Attention : BERTAttention
        • Context Self Attention : BERTSelfAttention
        • Self Attention output : BERTSelfOutput
      • Intermediate layer : BERTIntermediate
      • Output layer : BERTOutput


2) Functions / Other Classes

  • GeLU : gelu
  • Layer Normalization : BERTLayerNorm
  • Pooler : BERTPooler


3. Code Review (with PyTorch)
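
All code snippets in this section assume the following imports ( a minimal set inferred from the snippets themselves; the repo's file may import more ):

import copy
import json
import math

import six
import torch
import torch.nn as nn
import torch.nn.functional as F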

3-1) Main Modules

ContextBertModel

The overall algorithm of CGBERT

( At a high level, the only difference from BERT is Step 3 )

  • Step 1) attention_mask : create the padding mask

    ( to account for sentences having different numbers of tokens )

    ( Figure 2 )

  • Step 2) Sum the three kinds of embeddings

    ( each of the three embeddings is described in detail below )

  • Step 3) Embed the context ( 8 possible context ids \(\rightarrow\) a hidden_size-dimensional vector )

  • Step 4) Pass through \(L\) Encoder layers

    • each layer consists of Self-Attention, an FFNN, etc.
  • Step 5) Pooling

class ContextBertModel(nn.Module):
    def __init__(self, config: BertConfig):
        super(ContextBertModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        self.encoder = ContextBERTEncoder(config)
        self.pooler = ContextBERTPooler(config)
        self.context_embeddings = nn.Embedding(2*4, config.hidden_size)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, device=None, context_ids=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        
        # Step 1) Create the attention mask
        ## build a 3D attention mask from the 2D padding mask
        ## mask size = [batch_size, 1, 1, seq_len]
        mask3d = attention_mask.unsqueeze(1).unsqueeze(2).float()
        mask3d = (1.0 - mask3d) * -10000.0

        # Step 2) Sum the three kinds of embeddings
        embedded = self.embeddings(input_ids, token_type_ids)

        #-----------------------------------------------#
        # Step 3) Embed the context
        #-----------------------------------------------#
        seq_len = embedded.shape[1]
        C_embed = self.context_embeddings(context_ids).squeeze(dim=1)
        embedded_c = torch.stack(seq_len*[C_embed], dim=1)

        # Step 4) Pass through L (Transformer) Encoder layers
        ## difference from vanilla BERT: both the word embeddings and the context embedding are encoded
        all_encoder_layers = self.encoder(embedded, mask3d,device,embedded_c)
        
        # Step 5) Pooling ( on the output of the last, L-th encoder layer )
        last_encoder_layer = all_encoder_layers[-1]
        pooled_output = self.pooler(last_encoder_layer, attention_mask)
        return pooled_output
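

As a quick sanity check of the overall flow, the snippet below runs one forward pass with toy hyperparameters ( a sketch, not part of the repo; it assumes all the classes defined later in this review ):

config = BertConfig(vocab_size=100, hidden_size=16, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=32,
                    max_position_embeddings=32, type_vocab_size=2)
model = ContextBertModel(config)

input_ids = torch.randint(0, 100, (4, 10))   # (batch_size, seq_len)
context_ids = torch.randint(0, 8, (4, 1))    # one of the 2*4 context ids per example
pooled = model(input_ids, context_ids=context_ids)
print(pooled.shape)                           # torch.Size([4, 16])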


BERT Embedding

( Figure 2 )

( Same as BERT )

class BERTEmbeddings(nn.Module):
    def __init__(self, config):
        super(BERTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        # (1) Create position indices (one per token) for the position embeddings
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
		
        # (2) Sum the three kinds of embeddings
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        
        # (3) Layer Normalization + Dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
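
A quick shape check ( a sketch with toy values, using the BertConfig class shown at the end of this review ):

config = BertConfig(vocab_size=100, hidden_size=16,
                    max_position_embeddings=32, type_vocab_size=2)
emb = BERTEmbeddings(config)

input_ids = torch.randint(0, 100, (4, 10))   # (batch_size, seq_len)
out = emb(input_ids)                          # token_type_ids defaults to zeros
print(out.shape)                              # torch.Size([4, 10, 16])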


Context BERT Encoder

( Figure 2 )

( Identical to BERT except that a layer for encoding the context is added )

  • a deep_context_transform_layer is added
class ContextBERTEncoder(nn.Module):
    def __init__(self, config):
        super(ContextBERTEncoder, self).__init__()
        # (1) Standard BERT layers
        layer = ContextBERTLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])    
        
        # (2) Context transform layers ( the added deep_context_transform_layer )
        layer_c = nn.Linear(2*config.hidden_size, config.hidden_size)
        self.context_layer = nn.ModuleList([copy.deepcopy(layer_c) for _ in range(config.num_hidden_layers)])  

    def forward(self, hidden_states, attention_mask,
                device=None, context_embeddings=None):
        all_encoder_layers = []
        layer_index = 0
        for layer_module in self.layer:
            
            # update context
            ## shape of the updated context embedding : (batch_size, seq_len, d_hidden)
            deep_context_hidden = torch.cat([context_embeddings, hidden_states], dim=-1)
            deep_context_hidden = self.context_layer[layer_index](deep_context_hidden)
            deep_context_hidden += context_embeddings
            
            # BERT encoding
            hidden_states = layer_module(hidden_states, attention_mask,device, deep_context_hidden)
            all_encoder_layers.append(hidden_states)
            layer_index += 1
        return all_encoder_layers
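
A minimal shape check for the encoder ( a sketch with toy values; the additive mask uses 0 for kept positions and -10000.0 for padding, as built in ContextBertModel above ):

config = BertConfig(vocab_size=100, hidden_size=16, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=32)
encoder = ContextBERTEncoder(config)

hidden = torch.randn(4, 10, 16)        # word-side hidden states
mask3d = torch.zeros(4, 1, 1, 10)      # additive attention mask ( nothing masked here )
C = torch.randn(4, 10, 16)             # context embedding, broadcast over positions
layers = encoder(hidden, mask3d, context_embeddings=C)
print(len(layers), layers[-1].shape)   # 2 torch.Size([4, 10, 16])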


[A] Context BERT Layer

( Identical to BERT except that ContextBERTAttention is used instead of BERTAttention )

class ContextBERTLayer(nn.Module):
    def __init__(self, config):
        super(ContextBERTLayer, self).__init__()
        #-----------------------------------------------#
        self.attention = ContextBERTAttention(config) # (Step 1)
        #-----------------------------------------------#
        self.intermediate = BERTIntermediate(config) # (Step 2)
        self.output = BERTOutput(config) # (Step 3)

    def forward(self, hidden_states, attention_mask,
                device=None, C_embed=None):
        attention_output = self.attention(hidden_states, attention_mask,device, C_embed)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


[A-1] Context BERT Attention

class ContextBERTAttention(nn.Module):
    def __init__(self, config):
        super(ContextBERTAttention, self).__init__()
        #-----------------------------------------------#
        self.self = ContextBERTSelfAttention(config) # (Step 1)
        #-----------------------------------------------#
        self.output = BERTSelfOutput(config) # (Step 2)

    def forward(self, input_tensor, attention_mask,device=None, C_embed=None):
        self_output = self.self(input_tensor, attention_mask, device, C_embed)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


[A-1-a] Self Attention

[ New attention matrix ] :

  • a linear combination of the regular softmax attention matrix & the quasi-attention matrix
  • (paper) \(\hat{\mathbf{A}}^{h}=\mathbf{A}_{\text {Self-Attn }}^{h}+\lambda_{A}^{h} \mathbf{A}_{\text {Quasi-Attn }}^{h}\).
  • (code implementation) \(\hat{\mathbf{A}}^{h}=\mathbf{A}_{\text {Quasi-Attn }}^{h}\)


How do we compute \(\mathbf{A}_{\text {Quasi-Attn }}^{h}\)?

  • compute the quasi-context query \(\mathbf{C}_{Q}^{h}\)
  • and the quasi-context key \(\mathbf{C}_{K}^{h}\)

\(\left[\begin{array}{c} \mathbf{C}_{Q}^{h} \\ \mathbf{C}_{K}^{h} \end{array}\right]=\mathbf{C}^{h}\left[\begin{array}{l} \mathbf{Z}_{Q} \\ \mathbf{Z}_{K} \end{array}\right]\).


Quasi-attention matrix ( \(\mathbf{A}_{\text {Quasi-Attn }}^{h}\) ) :

\(\mathbf{A}_{\text {Quasi-Attn }}^{h}=\alpha \cdot \operatorname{sigmoid}\left(\frac{f_{\psi}\left(\mathbf{C}_{Q}^{h}, \mathbf{C}_{K}^{h}\right)}{\sqrt{d_{h}}}\right)\).

  • \(\alpha\) : scaling factor … set to 1 here
  • \(f_{\psi}(\cdot)\) : similarity measure ( between the quasi-context query and key ) … a dot product is used here
  • therefore, \(\mathbf{A}_{\text {Quasi-Attn }}^{h}\) takes values between 0 and 1 ( a small numeric sketch follows )
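
A toy, single-head sketch of this computation ( \(\alpha = 1\), dot-product similarity; hypothetical shapes, not the repo's code ):

batch_size, seq_len, d_h = 2, 5, 8
C_q = torch.randn(batch_size, seq_len, d_h)         # quasi-context query C_Q^h
C_k = torch.randn(batch_size, seq_len, d_h)         # quasi-context key   C_K^h

# A_quasi = alpha * sigmoid( C_Q C_K^T / sqrt(d_h) ), with alpha = 1
A_quasi = torch.sigmoid(C_q @ C_k.transpose(-1, -2) / math.sqrt(d_h))
print(A_quasi.shape)                                 # torch.Size([2, 5, 5])
print(float(A_quasi.min()), float(A_quasi.max()))    # every entry lies in (0, 1)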


( The following is described in the paper but is not reflected in the code )

The bidirectional gating factor \(\lambda_A\) is then defined as follows:

\(\begin{gathered} {\left[\begin{array}{c} \lambda_{Q}^{h} \\ \lambda_{K}^{h} \end{array}\right]=\operatorname{sigmoid}\left(\left[\begin{array}{c} \mathbf{Q}^{h} \\ \mathbf{K}^{h} \end{array}\right]\left[\begin{array}{c} \mathbf{V}_{Q}^{h} \\ \mathbf{V}_{K}^{h} \end{array}\right]+\left[\begin{array}{c} \mathbf{C}_{Q}^{h} \\ \mathbf{C}_{K}^{h} \end{array}\right]\left[\begin{array}{c} \mathbf{V}_{Q}^{C} \\ \mathbf{V}_{K}^{C} \end{array}\right]\right)} \\ \lambda_{A}^{h}=1-\left(\beta \cdot \lambda_{Q}^{h}+\gamma \cdot \lambda_{K}^{h}\right) \end{gathered}\).

  • allows \(\lambda_{Q}^{h}\) and \(\lambda_{K}^{h}\) to differ across heads
  • \(\beta=1\) and \(\gamma=1\) are used
  • therefore, \(\lambda_A\) lies between -1 and 1

  • as a result, the final attention \(\hat{\mathbf{A}}\) lies between -1 and 2! ( a quick check follows )
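
Checking these ranges using only the definitions above:

\(\lambda_{Q}^{h}, \lambda_{K}^{h} \in(0,1) \;\Rightarrow\; \lambda_{A}^{h}=1-\left(\lambda_{Q}^{h}+\lambda_{K}^{h}\right) \in(-1,1)\).

\(\mathbf{A}_{\text {Self-Attn }}^{h} \in[0,1],\; \mathbf{A}_{\text {Quasi-Attn }}^{h} \in(0,1) \;\Rightarrow\; \hat{\mathbf{A}}^{h}=\mathbf{A}_{\text {Self-Attn }}^{h}+\lambda_{A}^{h} \mathbf{A}_{\text {Quasi-Attn }}^{h} \in(-1,2)\).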


class ContextBERTSelfAttention(nn.Module):
    def __init__(self, config):
        super(ContextBERTSelfAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        
        ## (1) Weight (matrices)
        self.Wq = nn.Linear(config.hidden_size, self.all_head_size)
        self.Wk = nn.Linear(config.hidden_size, self.all_head_size)
        self.Wv = nn.Linear(config.hidden_size, self.all_head_size)     
        self.Wcq = nn.Linear(self.attention_head_size, self.attention_head_size)
        self.Wck = nn.Linear(self.attention_head_size, self.attention_head_size)
        
        ## (2) Lambda (scalar)
        self.lambda_Qc = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_Q = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_Kc = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_K = nn.Linear(self.attention_head_size, 1, bias=False)
		
        ## (3) Dropout & activation
        self.dropout = nn.Dropout(config.A_prob_dropout_prob)
        self.lambda_act = nn.Sigmoid()
        
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask,
                device=None, C_embed=None):
        #-----------------------------#
        Q = self.transpose_for_scores(self.Wq(hidden_states))
        K = self.transpose_for_scores(self.Wk(hidden_states))
        V = self.transpose_for_scores(self.Wv(hidden_states))
        #-----------------------------#
        # (Note) the context is shared and is the same for every head,
        # so split it into heads first, then project it per head
        C_embed = self.transpose_for_scores(C_embed)
        Qc = self.Wcq(C_embed)
        Kc = self.Wck(C_embed)
        #----------------------------------------------------------#
        # [1] Q_final = (1-lambda_Q) * Q + lambda_Q * Context_Q
        #----------------------------------------------------------#
        lambda_Q1 = self.lambda_Qc(Qc) # (1) context-related lambda_Q
        lambda_Q2 = self.lambda_Q(Q)   # (2) regular lambda_Q
        lambda_Q = lambda_Q1 + lambda_Q2
        lambda_Q = self.lambda_act(lambda_Q)
        Q_final = (1 - lambda_Q) * Q + lambda_Q * Qc

        #----------------------------------------------------------#
        # [2] K_final = (1-lambda_K) * K + lambda_K * Context_K
        #----------------------------------------------------------#
        lambda_K1 = self.lambda_Kc(Kc) # (1) context-related lambda_K
        lambda_K2 = self.lambda_K(K)   # (2) regular lambda_K
        lambda_K = lambda_K1 + lambda_K2
        lambda_K = self.lambda_act(lambda_K)
        K_final = (1 - lambda_K) * K + lambda_K * Kc

        #----------------------------------------------------------#
        # Compute the attention matrix (A)
        #-- Difference) the Query and Key used here are not the plain
        #-- Query and Key but the CONTEXT-infused Query and Key
        #-- ( lambda is the parameter controlling how much context is mixed in )
        A_score = torch.matmul(Q_final, K_final.transpose(-1, -2))
        A_score = A_score / math.sqrt(self.attention_head_size)
        A_score = A_score + attention_mask
        A = nn.Softmax(dim=-1)(A_score)
        A = self.dropout(A)
        
        #----------------------------------------------------------#
        ## compute the context layer ( A x V )
        context_layer = torch.matmul(A, V)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
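
A minimal shape check for the context-guided self-attention ( a sketch with toy values; C_embed must have the same hidden size as the word hidden states ):

config = BertConfig(vocab_size=100, hidden_size=16, num_attention_heads=2)
attn = ContextBERTSelfAttention(config)

hidden = torch.randn(4, 10, 16)      # word hidden states
mask3d = torch.zeros(4, 1, 1, 10)    # additive attention mask ( nothing masked here )
C = torch.randn(4, 10, 16)           # per-position context embedding
out = attn(hidden, mask3d, C_embed=C)
print(out.shape)                      # torch.Size([4, 10, 16])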


[A-1-b] Self Attention Output

( Same as BERT )

class BERTSelfOutput(nn.Module):
    def __init__(self, config):
        super(BERTSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


[A-2] Intermediate Layer

  • ( Same as BERT )
class BERTIntermediate(nn.Module):
    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = gelu

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


[A-3] Output Layer

( Same as BERT )

class BERTOutput(nn.Module):
    def __init__(self, config):
        super(BERTOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


3-2) Functions / Other Classes

GeLU

\(\operatorname{GELU}(x)=0.5 x\left(1+\tanh \left(\sqrt{2 / \pi}\left(x+0.044715 x^{3}\right)\right)\right)\).

( The formula above is the tanh approximation; the code below uses the exact erf form \(0.5 x\left(1+\operatorname{erf}\left(x / \sqrt{2}\right)\right)\), and the two agree closely. )

def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
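
A quick numeric check that the exact erf form ( used in the code ) and the tanh approximation ( shown in the formula ) agree ( a sketch, not from the repo ):

x = torch.linspace(-5, 5, steps=101)
gelu_exact = gelu(x)                                   # erf form, as in the code above
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
print(float((gelu_exact - gelu_tanh).abs().max()))     # small ( well under 1e-2 )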


Layer Normalization

( Same as BERT )

  • \(\gamma\) : scaling parameter / \(\beta\) : shift parameter
class BERTLayerNorm(nn.Module):
    def __init__(self, config, variance_epsilon=1e-12):
        super(BERTLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        var = (x - mu).pow(2).mean(-1, keepdim=True)   # (biased) variance
        x = (x - mu) / torch.sqrt(var + self.variance_epsilon)
        return self.gamma * x + self.beta
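
This is numerically the same as torch.nn.LayerNorm over the last dimension with the same eps ( a quick check with a toy hidden size ):

config = BertConfig(hidden_size=16)
ln_bert = BERTLayerNorm(config)
ln_torch = nn.LayerNorm(16, eps=1e-12)

x = torch.randn(4, 10, 16)
print(torch.allclose(ln_bert(x), ln_torch(x), atol=1e-5))   # True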


ContextBERTPooler

  • ( Pooling in vanilla BERT ) performed only on the first [CLS] token
  • ( Pooling in CGBERT ) a local context attention mechanism
  • the pooled vector then passes through a Dense layer & activation function
class ContextBERTPooler(nn.Module):
    def __init__(self, config):
        super(ContextBERTPooler, self).__init__()
        self.attention_gate = nn.Sequential(nn.Linear(config.hidden_size, 32),
                              nn.ReLU(),
                              nn.Dropout(config.hidden_dropout_prob),
                              nn.Linear(32, 1))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, attention_mask):
        #-------------- Pooling that differs from vanilla BERT --------------#
        #  In pooling, we are using a local context attention mechanism
        attn_scores = self.attention_gate(hidden_states)
        extended_attention_mask = attention_mask.unsqueeze(dim=-1)
        attn_scores = attn_scores.masked_fill(extended_attention_mask == 0, -1e9)
        attn_scores = F.softmax(attn_scores, dim=1)
        hs_pooled = torch.matmul(attn_scores.permute(0,2,1), hidden_states).squeeze(dim=1)
        #-----------------------------------------------------#
        
        pooled_output = self.dense(hs_pooled)
        pooled_output = self.activation(pooled_output)
        return pooled_output
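
A minimal sketch of the pooling step ( toy values; the 2D padding mask marks real tokens with 1 and padded positions with 0 ):

config = BertConfig(hidden_size=16)
pooler = ContextBERTPooler(config)

hidden = torch.randn(4, 10, 16)   # output of the last encoder layer
mask = torch.ones(4, 10)          # 2D padding mask
mask[:, 7:] = 0                   # pretend the last 3 positions are padding
pooled = pooler(hidden, mask)
print(pooled.shape)               # torch.Size([4, 16])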


[ Configuration ]

Configuration class to store the configuration of a BertModel

( values in parentheses are the defaults )

  • vocab_size : size of the vocabulary (32000)
  • hidden_size : hidden dimension of the encoder layers and the pooler layer (768)
  • num_hidden_layers : number of hidden layers in the Encoder (12)
  • num_attention_heads : number of attention heads (12)
  • intermediate_size : dimension of the intermediate (FFNN) layer (3072)
  • hidden_act : activation function used in the encoder ( & pooler ) (gelu)
  • hidden_dropout_prob : dropout probability (0.1)
  • A_prob_dropout_prob : dropout probability applied to the attention matrix (0.1)
  • max_position_embeddings : maximum sequence length (512)
  • type_vocab_size : vocabulary size for token_type_ids (16)
  • initializer_range : standard deviation of the truncated normal distribution used for weight initialization (0.02)
class BertConfig(object):
    def __init__(self,
                vocab_size=32000,
                hidden_size=768,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                hidden_act="gelu",
                hidden_dropout_prob=0.1,
                A_prob_dropout_prob=0.1,
                max_position_embeddings=512,
                type_vocab_size=16,
                initializer_range=0.02,
                full_pooler=False):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.A_prob_dropout_prob = A_prob_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.full_pooler = full_pooler

    #---------------------------------------------------------------#
    # Build the configuration from a [ dictionary ] or a [ JSON ] file
    #---------------------------------------------------------------#

    ## From a [ Dictionary ]
    @classmethod
    def from_dict(cls, json_object):

        config = BertConfig(vocab_size=None)
        for (key, value) in six.iteritems(json_object):
            config.__dict__[key] = value
        return config

    ## From a [ JSON ] file ( read, then build via from_dict )
    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
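

A brief usage sketch for the (de)serialization helpers ( assuming the defaults above ):

config = BertConfig(vocab_size=30522)
config_dict = config.to_dict()                  # plain dict of all fields
restored = BertConfig.from_dict(config_dict)    # rebuild a config from the dict
print(restored.hidden_size)                     # 768
print(config.to_json_string()[:40])             # beginning of the JSON dump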