Code Review for “BERT for ABSA”

References:

  • https://github.com/akkarimi/BERT-For-ABSA/blob/master/src/absa_data_utils.py
  • https://github.com/akkarimi/BERT-For-ABSA/blob/master/src/run_ae.py
  • https://github.com/akkarimi/BERT-For-ABSA/blob/master/src/run_asc.py


( AE vs ASC )

Aspect Extraction (AE):

  • given (1) a review sentence (“The retina display is great.”), find the aspect terms (“retina display”);

Aspect Sentiment Classification (ASC):

  • given (1) an aspect term (“retina display”) and (2) a review sentence (“The retina display is great.”), classify the sentiment polarity toward that aspect (positive); a toy example of both formats follows below.
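
A sketch of the two data formats (the field names here are illustrative, not the repo's actual schema):

sentence = ["The", "retina", "display", "is", "great", "."]

# AE: token-level tagging -- every token gets a B/I/O label
ae_labels = ["O", "B", "I", "O", "O", "O"]          # "retina display" is the aspect term

# ASC: sentence-level classification of one (aspect, sentence) pair
asc_example = {
    "aspect":   "retina display",
    "sentence": "The retina display is great.",
    "label":    "positive",                         # one of positive / negative / neutral
}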


1. BERT for ASC (Aspect Sentiment Classification)

bert_forward:

  • the output after passing through the three embeddings (token, position, segment), the 12 encoder layers, and the pooler

adv_attack:

  • generate the adversarial input (= the perturbed sentence) by adding noise in the gradient direction: emb + epsilon * g / ||g||, where g is the gradient of the loss w.r.t. the embedding output

adversarial_loss:

  • feed the perturbed sentence (+ attention mask) through the encoder and compute its cross-entropy loss

forward:

  • Step 1) bert_forward
  • Step 2) prediction (= dropout & classifier to produce the logits)
  • Step 3)
    • (case 1) labels are given
      • not in training: compute the loss
      • in training: compute the loss & the adversarial loss
    • (case 2) labels are not given
      • return the logits

import torch
import torch.nn as nn
from torch.autograd import grad                         # used by adv_attack
from pytorch_pretrained_bert.modeling import BertModel  # assuming the pytorch_pretrained_bert package the repo builds on

class BertForABSA(BertModel):
    def __init__(self, config, num_classes=3, dropout=None, epsilon=None):
        super(BertForABSA, self).__init__(config)
        self.num_classes = num_classes 
        self.epsilon = epsilon
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        pooled, embedded = self.bert_forward(input_ids, token_type_ids, 
                                             attention_mask=attention_mask, 
                                             output_all_encoded_layers=False)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        
        # Case 1) labels are given -> compute the loss(es)
        if labels is not None:
            loss_ = self.loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
            
            # training mode: also compute the adversarial loss
            if pooled.requires_grad: 
                perturbed_snt = self.adv_attack(embedded, loss_)
                perturbed_snt = self.replace_cls_token(embedded, perturbed_snt)
                adv_loss = self.adversarial_loss(perturbed_snt, attention_mask, labels)
                return loss_, adv_loss
            return loss_
        
        # Case 2) no labels -> return the raw logits
        else:
            return logits

    #--------- Generate the adversarial (perturbed) input ------------#
    def adv_attack(self, emb, loss):
        # gradient of the loss w.r.t. the embedding output
        loss_grad = grad(loss, emb, retain_graph=True)[0]
        # per-example L2 norm over the sequence and hidden dimensions
        loss_grad_norm = torch.sqrt(torch.sum(loss_grad**2, (1,2)))
        # take an epsilon-sized step along the normalized gradient direction
        perturbed_snt = emb + self.epsilon * (loss_grad/(loss_grad_norm.reshape(-1,1,1)))
        return perturbed_snt
    #---------------------------------------------------#
    def replace_cls_token(self, emb, perturbed):
        # keep the original [CLS] embedding; use the perturbed values for all other positions
        condition = torch.zeros_like(emb)
        condition[:, 0, :] = 1
        perturbed_snt = torch.where(condition.bool(), emb, perturbed)
        return perturbed_snt
    
    #-------- BERT forward: embeddings -> encoder -> pooler --------#
    def bert_forward(self, input_ids, token_type_ids=None, 
                            attention_mask=None, output_all_encoded_layers=False):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        mask3d = attention_mask.unsqueeze(1).unsqueeze(2)
        mask3d = mask3d.to(dtype=next(self.parameters()).dtype) 
        mask3d = (1.0 - mask3d) * -10000.0
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output, mask3d, output_all_encoded_layers=output_all_encoded_layers)
        pooled = self.pooler(encoded_layers[-1])
        return pooled, embedding_output
    #---------------------------------------------------#

    def adversarial_loss(self, perturbed, attention_mask, labels):
        #------ (1) Build the attention mask --------#
        if attention_mask is None:
            # default to attending over every position of the perturbed input
            attention_mask = torch.ones(perturbed.shape[:2], dtype=torch.long, device=perturbed.device)
        mask3d = attention_mask.unsqueeze(1).unsqueeze(2)
        mask3d = mask3d.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        mask3d = (1.0 - mask3d) * -10000.0
        
        #------ (2) Encode the perturbed embeddings --------#
        encoded_layers = self.encoder(perturbed, mask3d, output_all_encoded_layers=False)
        layers_wo_last = self.pooler(encoded_layers[-1])
        layers_wo_last = self.dropout(layers_wo_last)
        
        #------ (3) Prediction & loss --------#
        logits = self.classifier(layers_wo_last)
        adv_loss = self.loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
        return adv_loss
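
During training, forward returns both losses and run_asc.py back-propagates through them. A minimal training-step sketch, assuming the two losses are simply summed (the exact weighting should be checked against run_asc.py); model, optimizer, and the batch tensors are taken from the surrounding training loop:

model.train()
loss, adv_loss = model(input_ids, token_type_ids=segment_ids,
                       attention_mask=input_mask, labels=label_ids)
total_loss = loss + adv_loss       # clean loss + adversarial loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()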


2. BERT for AE (Aspect Extraction)

[ Differences between AE & ASC ]

Difference 1) Goal

  • AE: check, for every token, whether it is part of an aspect term (B/I/O tagging)
  • ASC: classify the sentiment only toward the given aspect term


Difference 2) Tokenization

  • AE: a custom subword_tokenize, which keeps the per-token B/I/O labels aligned with the word pieces
  • ASC: the standard BERT tokenizer
for (ex_index, example) in enumerate(examples):
    if mode != "ae":
        tokens_a = tokenizer.tokenize(example.text_a)
    else:  # AE: subword tokenization that also realigns the labels
        tokens_a, labels_a, example.idx_map = tokenizer.subword_tokenize(
            [token.lower() for token in example.text_a], example.label)

( Looking more closely at AE’s “subword_tokenize”… )

class ABSATokenizer(BertTokenizer):     
    def subword_tokenize(self, tokens, labels): # for AE
        split_tokens, split_labels= [], []
        idx_map=[]
        for ix, token in enumerate(tokens):
            sub_tokens=self.wordpiece_tokenizer.tokenize(token)
            for jx, sub_token in enumerate(sub_tokens):
                split_tokens.append(sub_token)
                if labels[ix]=="B" and jx>0:
                    split_labels.append("I")
                else:
                    split_labels.append(labels[ix])
                idx_map.append(ix)
        return split_tokens, split_labels, idx_map
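
In other words, when a B-labeled word is split into several word pieces, only the first piece keeps "B" and the rest become "I", while idx_map records the original word index of every piece. A small usage sketch (assuming the pytorch_pretrained_bert-style from_pretrained constructor; the exact word pieces depend on the vocabulary):

tokenizer = ABSATokenizer.from_pretrained("bert-base-uncased")
tokens = ["the", "touchpad", "is", "unresponsive"]
labels = ["O", "B", "O", "O"]
split_tokens, split_labels, idx_map = tokenizer.subword_tokenize(tokens, labels)
# possible result:
# split_tokens -> ["the", "touch", "##pad", "is", "un", "##res", "##pon", "##sive"]
# split_labels -> ["O",   "B",     "I",     "O",  "O",  "O",     "O",     "O"    ]
# idx_map      -> [ 0,     1,       1,       2,    3,    3,       3,       3     ]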


Difference 3) The output returned by bert_forward

  • (1) ASC: the pooled output & the embedding output
  • (2) AE: the sequence output (last encoder layer) & the embedding output
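
Concretely, with batch size B, sequence length T, and hidden size H (768 for bert-base), the shapes are:

# ASC bert_forward: pooled   (B, H)    -- pooler over the [CLS] token, one vector per sentence
#                   embedded (B, T, H) -- embedding-layer output, kept so adv_attack can differentiate w.r.t. it
# AE  bert_forward: sequence_output (B, T, H) -- last encoder layer, one vector per token
#                   embedded        (B, T, H)
# Each model then applies a Linear(H, num_classes=3) classifier:
# once per sentence for ASC, at every token position for AE.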


class BertForABSA(BertModel):
    def __init__(self, config, num_classes=3, dropout=None, epsilon=None):
        super(BertForABSA, self).__init__(config)
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        sequence_output, embedded = self.bert_forward(input_ids, 
                                                token_type_ids, 
                                                attention_mask, 
                                                output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
            loss_ = loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
            if sequence_output.requires_grad: 
                perturbed_sentence = self.adv_attack(embedded, loss_, self.epsilon)
                adv_loss = self.adversarial_loss(perturbed_sentence, attention_mask, labels)
                return loss_, adv_loss
            return loss_
        else:
            return logits

    def adv_attack(self, emb, loss, epsilon):
        loss_grad = grad(loss, emb, retain_graph=True)[0]
        loss_grad_norm = torch.sqrt(torch.sum(loss_grad**2, (1,2)))
        perturbed_sentence = emb + epsilon * (loss_grad/(loss_grad_norm.reshape(-1,1,1)))
        return perturbed_sentence

    def adversarial_loss(self, perturbed, attention_mask, labels):
        #------ (1) Build the attention mask --------#
        mask3d = attention_mask.unsqueeze(1).unsqueeze(2)
        mask3d = mask3d.to(dtype=next(self.parameters()).dtype)
        mask3d = (1.0 - mask3d) * -10000.0
        
        #------ (2) Encode the perturbed embeddings --------#
        encoded_layers = self.encoder(perturbed, mask3d, output_all_encoded_layers=False)
        encoded_layers_last = encoded_layers[-1]
        encoded_layers_last = self.dropout(encoded_layers_last)
        
        #------ (3) Prediction & loss --------#
        #### Note: ignore_index=-1, so ignored/padded token positions do not contribute to the loss
        loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        logits = self.classifier(encoded_layers_last)
        adv_loss = loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
        return adv_loss

    #-------- BERT forward: embeddings -> encoder (token-level output, no pooler) --------#
    def bert_forward(self, input_ids, token_type_ids=None, 
                        attention_mask=None, output_all_encoded_layers=False):
        #-------------------------------------#
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        #-------------------------------------#
        mask3d = attention_mask.unsqueeze(1).unsqueeze(2)
        mask3d = mask3d.to(dtype=next(self.parameters()).dtype) 
        mask3d = (1.0 - mask3d) * -10000.0
        #-------------------------------------#
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output, mask3d, output_all_encoded_layers=output_all_encoded_layers)
        layers_wo_last = encoded_layers[-1]
        return layers_wo_last, embedding_output
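
A minimal prediction sketch for AE (evaluation mode, no labels passed, so forward returns token-level logits). The label order ["O", "B", "I"] is an assumption here and must match the label list used when the features were built:

model.eval()
with torch.no_grad():
    logits = model(input_ids, token_type_ids=segment_ids,
                   attention_mask=input_mask)     # shape (B, T, num_classes)
pred_ids = logits.argmax(dim=-1)                  # shape (B, T): predicted class per word piece
# idx_map (from subword_tokenize) can then be used to map the word-piece
# predictions back onto the original words.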