Code Review for HAN (Hierarchical Attention Network)

Reference

  • https://simonjisu.github.io/nlp/2018/07/05/packedsequence.html
  • https://github.com/Hazoom/bert-han/blob/master/src/models/han.py
  • https://github.com/Hazoom/bert-han/blob/master/src/models/bert_wordattention.py
  • https://github.com/Hazoom/bert-han/blob/master/src/models/sentenceattention.py


(Note) pad & pack sequence

Suppose we have the following input ( input_seq2idx ):

  • batch size = 5 (5 sentences)
  • max sentence length = 6 (the longest sentence has 6 tokens)
  • hidden size = 2 (hidden dimension)
input_seq2idx
============================================
tensor([[  1,  16,   7,  11,  13,   2],
        [  1,  16,   6,  15,   8,   0],
        [ 12,   9,   0,   0,   0,   0],
        [  5,  14,   3,  17,   0,   0],
        [ 10,   0,   0,   0,   0,   0]])


The input ( input_seq2idx ) is first passed through an embedding layer:

  • embedded = embed(input_seq2idx)

Then…


(1) pack_padded_sequence

  • packed_output = pack_padded_sequence(embedded, input_lengths.tolist(), batch_first=True)

  • Output sizes:

    • packed_output[0].size() : torch.Size([18, 2]) … 18 non-pad tokens in total (6+5+4+2+1) & hidden dimension 2
    • packed_output[1] : tensor([5, 4, 3, 3, 2, 1]) … the batch_sizes, i.e. how many sequences are still active at each time step


(2) pad_packed_sequence

  • output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

  • Output sizes:

    • output.size() : torch.Size([5, 6, 2])
    • output_lengths : tensor([6, 5, 4, 2, 1])
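
Putting (1) and (2) together, here is a minimal runnable sketch of the round trip described above (the embedding weights are random, so only the shapes, batch sizes, and lengths are meaningful):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

input_seq2idx = torch.tensor([[ 1, 16,  7, 11, 13,  2],
                              [ 1, 16,  6, 15,  8,  0],
                              [12,  9,  0,  0,  0,  0],
                              [ 5, 14,  3, 17,  0,  0],
                              [10,  0,  0,  0,  0,  0]])
input_lengths = (input_seq2idx != 0).sum(dim=1)           # tensor([6, 5, 2, 4, 1])

# pack_padded_sequence expects the batch sorted by length (descending)
input_lengths, sort_idx = input_lengths.sort(descending=True)
input_seq2idx = input_seq2idx[sort_idx]

embed = nn.Embedding(num_embeddings=18, embedding_dim=2, padding_idx=0)
embedded = embed(input_seq2idx)                           # (5, 6, 2)

packed_output = pack_padded_sequence(embedded, input_lengths.tolist(), batch_first=True)
print(packed_output.data.size())                          # torch.Size([18, 2])
print(packed_output.batch_sizes)                          # tensor([5, 4, 3, 3, 2, 1])

output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
print(output.size())                                      # torch.Size([5, 6, 2])
print(output_lengths)                                     # tensor([6, 5, 4, 2, 1])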


1. Word Attention

Example data: 100 documents

  • (5, 7, 3, 12, …, 8) sentences per document (1,050 sentences in total), at most 12 sentences per document
    • (54, 23, 82, 25, …, 77) words per sentence (80,000 words in total), at most 82 words per sentence


[INPUT] Meaning and dimensions (a toy example follows this list):

  • docs : the encoded documents … size = ( num_docs, padded doc length, padded sentence length ) = (100, 12, 82)
  • doc_lengths : number of sentences per document … size = ( num_docs ) = 100
    • values : (5, 7, 3, 12, …, 8)
  • sent_lengths : number of words per sentence … size = ( num_docs, padded doc length ) = (100, 12)
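
For illustration, a hypothetical toy version of these three inputs (2 documents, at most 3 sentences per document, at most 4 words per sentence; all token ids made up):

import torch

docs = torch.tensor([[[4, 9, 2, 0],        # doc 0, sentence 0 (3 words + padding)
                      [7, 5, 0, 0],        # doc 0, sentence 1 (2 words + padding)
                      [0, 0, 0, 0]],       # doc 0, padding sentence
                     [[3, 8, 6, 1],        # doc 1, sentence 0 (4 words)
                      [2, 0, 0, 0],        # doc 1, sentence 1 (1 word + padding)
                      [9, 4, 7, 0]]])      # doc 1, sentence 2 (3 words + padding)
doc_lengths = torch.tensor([2, 3])         # number of sentences per document
sent_lengths = torch.tensor([[3, 2, 0],    # number of words per sentence (0 = padding sentence)
                             [4, 1, 3]])
# docs         : (num_docs, padded_doc_length, padded_sent_length) = (2, 3, 4)
# doc_lengths  : (num_docs,)
# sent_lengths : (num_docs, padded_doc_length)

The WordAttention module below consumes exactly these three tensors.
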
@registry.register("word_attention", "WordAttention")
class WordAttention(nn.Module):
    def __init__(self,
            device: str,
            preprocessor: AbstractPreproc,
            word_emb_size: int,
            dropout: float,
            recurrent_size: int,
            attention_dim: int):
        super().__init__()
        self._device = device
        self.preprocessor = preprocessor
        self.embedder: abstract_embeddings.Embedder = self.preprocessor.get_embedder()
        self.vocab: vocab.Vocab = self.preprocessor.get_vocab()
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        self.dropout = dropout
        self.attention_dim = attention_dim
        self.embedding = nn.Embedding(num_embeddings=len(self.vocab), embedding_dim=self.word_emb_size)
        assert self.recurrent_size % 2 == 0 # must be even: we use a bi-LSTM and split the size across both directions
        assert self.word_emb_size == self.embedder.dim


        # init embedding
        init_embed_list = []
        for index, word in enumerate(self.vocab):
            # (1) word is in the pretrained embedder -> use its lookup-table vector
            if self.embedder.contains(word):
                init_embed_list.append(self.embedder.lookup(word))
            # (2) word not in the embedder -> keep the randomly initialized embedding
            else:
                init_embed_list.append(self.embedding.weight[index])
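        # stack into a (vocab_size, word_emb_size) matrix and register it as the embedding
        # weight: known words start from pretrained vectors, unknown words keep their random init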
        init_embed_weight = torch.stack(init_embed_list, 0)
        self.embedding.weight = nn.Parameter(init_embed_weight)

        self.encoder = nn.LSTM(
            input_size=self.word_emb_size,
            hidden_size=self.recurrent_size // 2,  # half per direction (bi-LSTM)
            dropout=self.dropout,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.word_weight = nn.Linear(self.recurrent_size, self.attention_dim)
        self.context_weight = nn.Linear(self.attention_dim, 1)

    def recurrent_size(self):
        return self.recurrent_size

    def forward(self, docs, doc_lengths, sent_lengths):
        #########################
        ### Hierarchy level 1 ###
        #########################
        #-----------------------------------------#
        # (Step 1) Sort documents by number of sentences (descending)
        ### doc_lengths : 100 values (=5,7,3,12,....,8) -> 100 values (=12,11,....,3)
        ### sent_lengths : 100x12 (=54,23,...,77) -> 100x12 (rows permuted to the new document order)
        doc_lengths, doc_perm_idx = doc_lengths.sort(dim=0, descending=True)
        docs = docs[doc_perm_idx]
        sent_lengths = sent_lengths[doc_perm_idx] 
        
        #-----------------------------------------#
        # (Step 2) [packing] many docs -> one long batch of sentences
        ##### ( Make a long batch of sentences )
        ##### BEFORE : (num_docs, padded_doc_length, padded_sent_length)
        ##### AFTER  : (num_sents, padded_sent_length)
        ##### docs = 100x12x82 -> sents = 1050x82
        packed_sents = pack_padded_sequence(docs, lengths=doc_lengths.tolist(), batch_first=True)
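        # packed_sents.batch_sizes : how many documents are still "active" at each sentence position
        # packed_sents.data        : all non-padded sentences stacked together -> (num_sents, padded_sent_length)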
        doc_bs = packed_sents.batch_sizes 
        sents = packed_sents.data 
        
        #-------------------------------------------------#
        # (Step 3) [packing] per-document sentence lengths -> one long batch of sentence lengths
        ##### ( Make a long batch of sentence lengths )
        ##### BEFORE : (num_docs, padded_doc_length)
        ##### AFTER  : (num_sents)
        ##### sent_lengths = 100x12 (=54,23,...,77) -> 1050
        packed_sent_lengths = pack_padded_sequence(sent_lengths, lengths=doc_lengths.tolist(), batch_first=True)
        sent_lengths = packed_sent_lengths.data
        
        #-------------------------------------------------#
        # (Step 4) Sort sentences by number of words (descending)
        sent_lengths, sent_perm_idx = sent_lengths.sort(dim=0, descending=True)
        sents = sents[sent_perm_idx]
        #-----------------------------------------#
        
        #########################
        ### Hierarchy level 2 ###
        #########################
        #-----------------------------------------#
        # (Step 1) [embedding] embed the long batch of sentences
        input_ = self.dropout(self.embedding(sents))
        #-----------------------------------------#
        # (Step 2) [packing] sentences -> one long batch of words
        packed_words = pack_padded_sequence(input_, lengths=sent_lengths.tolist(), batch_first=True)
        #-----------------------------------------#
        # (Step 3) [encoding] run the bi-LSTM over the packed words
        packed_words, _ = self.encoder(packed_words)
        sentences_bs = packed_words.batch_sizes
        #-----------------------------------------#
        
        #################
        ### Attention ###
        #################
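        # word_weight + tanh : project each word's bi-LSTM state into the attention space
        # context_weight     : score each projected state against a learned word context vector -> one scalar per word
        # exp(score - max)   : softmax numerator (max subtracted for numerical stability);
        #                      the denominator (per-sentence sum) is applied after re-padding below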
        u_i = torch.tanh(self.word_weight(packed_words.data))
        u_w = self.context_weight(u_i).squeeze(1)
        att = torch.exp(u_w - u_w.max())

        ###########################
        ### Re-pad to sentences  ###
        ###########################
        # Restore as sentences by repadding
        att, _ = pad_packed_sequence(PackedSequence(att, sentences_bs), batch_first=True)
        att_weights = att / torch.sum(att, dim=1, keepdim=True)

        # Restore as sentences by repadding
        sents, _ = pad_packed_sequence(packed_words, batch_first=True)
        sents = sents * att_weights.unsqueeze(2)
        sents = sents.sum(dim=1)

        ###########################################
        ### Restore the original sentence order ###
        ###########################################
        _, sent_unperm_idx = sent_perm_idx.sort(dim=0, descending=False)
        sents = sents[sent_unperm_idx]
        att_weights = att_weights[sent_unperm_idx]

        return sents, doc_perm_idx, doc_bs, att_weights
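
The attention block above is the word-level attention of the HAN paper (Yang et al., 2016). In equation form, with h_it the bi-LSTM state of word t in sentence i, word_weight playing the role of W_w and context_weight the word context vector u_w (plus a bias term the paper does not have):

$$
u_{it} = \tanh(W_w h_{it} + b_w), \qquad
\alpha_{it} = \frac{\exp(u_{it}^{\top} u_w)}{\sum_{t'} \exp(u_{it'}^{\top} u_w)}, \qquad
s_i = \sum_{t} \alpha_{it} h_{it}
$$

s_i is the sentence vector returned in sents; the code only subtracts the global max inside the exp for numerical stability, which leaves the softmax unchanged.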

2. Sentence Attention

[INPUT] Meaning and dimensions:

  • sent_embeddings : sentence vectors from WordAttention … size : (num_sents, word recurrent size)
  • doc_perm_idx : permutation indices of the documents … size : (num_docs)
  • doc_bs : batch sizes from packing the documents … size : (max_doc_len)
  • word_att_weights : word attention weights … size : (num_sents, max_sent_len)


@registry.register("sentence_attention", "SentenceAttention")
class SentenceAttention(torch.nn.Module):
    def __init__(
            self,
            device: str,
            dropout: float,
            word_recurrent_size: int,
            recurrent_size: int,
            attention_dim: int,
    ):
        super().__init__()
        self._device = device
        self.word_recurrent_size = word_recurrent_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        self.dropout = dropout
        self.attention_dim = attention_dim
        self.encoder = nn.LSTM(
            input_size=self.word_recurrent_size,
            hidden_size=self.recurrent_size // 2,
            dropout=self.dropout,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # bi-LSTM output -> attention dimension
        self.sentence_weight = nn.Linear(self.recurrent_size, self.attention_dim)
        # attention dimension -> attention score (one scalar per sentence)
        self.sentence_context_weight = nn.Linear(self.attention_dim, 1)
		
        
    def recurrent_size(self):
        return self.recurrent_size

    def forward(self, sent_embeddings, doc_perm_idx, doc_bs, word_att_weights):
        #-----------------------------------------#
        # [Step 1] Apply dropout to the sentence embeddings
        sent_embeddings = self.dropout(sent_embeddings)
        
        #-----------------------------------------#
        # [Step 2] Encode the sentence embeddings with the bi-LSTM
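        # Reusing doc_bs here is valid because WordAttention returned the sentence
        # vectors in exactly the packed-by-document order that doc_bs describes.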
        packed_sentences, _ = self.encoder(PackedSequence(sent_embeddings, doc_bs))

        #-----------------------------------------#
        # [Step 3] Compute the sentence attention weights
        u_i = torch.tanh(self.sentence_weight(packed_sentences.data))
        u_w = self.sentence_context_weight(u_i).squeeze(1)
        att = torch.exp(u_w - u_w.max())

        # Re-pad the per-sentence scores back into documents
        att, _ = pad_packed_sequence(PackedSequence(att, doc_bs), batch_first=True)
        sent_att_weights = att / torch.sum(att, dim=1, keepdim=True)

        # Restore as documents by repadding
        docs, _ = pad_packed_sequence(packed_sentences, batch_first=True)

        # Compute document vectors
        docs = docs * sent_att_weights.unsqueeze(2)
        docs = docs.sum(dim=1)

        # Restore as documents by repadding
        word_att_weights, _ = pad_packed_sequence(PackedSequence(word_att_weights, doc_bs), batch_first=True)

        # Restore the original order of documents (undo the first sorting)
        _, doc_unperm_idx = doc_perm_idx.sort(dim=0, descending=False)
        docs = docs[doc_unperm_idx]

        word_att_weights = word_att_weights[doc_unperm_idx]
        sent_att_weights = sent_att_weights[doc_unperm_idx]

        return docs, word_att_weights, sent_att_weights
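
SentenceAttention is the sentence-level counterpart of the equations above: with h_i the bi-LSTM state of sentence i, sentence_weight as W_s and sentence_context_weight as the sentence context vector u_s,

$$
u_i = \tanh(W_s h_i + b_s), \qquad
\alpha_i = \frac{\exp(u_i^{\top} u_s)}{\sum_{i'} \exp(u_{i'}^{\top} u_s)}, \qquad
v = \sum_{i} \alpha_i h_i
$$

where v is the document vector (docs in the code) that gets passed on to the classifier.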


3. HAN (Hierarchical Attention Network)

Dimensions:

  • docs : (num_docs, padded_doc_length, padded_sent_length)

  • doc_lengths : (num_docs)

  • sent_lengths : (num_docs, padded_doc_length)

  • attention_masks : (num_docs, padded_doc_length, padded_sent_length)

    ( = same size as docs )

  • token_type_ids : (num_docs, padded_doc_length, padded_sent_length)

    ( = same size as docs; both are only passed on when the BERT-based word attention is used )


@registry.register('model', 'HAN')
class HANModel(torch.nn.Module):
    # class Preprocessor(abstract_preprocessor.AbstractPreproc): omitted here

    def __init__(self, preprocessor, device, word_attention, sentence_attention, final_layer_dim, final_layer_dropout):
        super().__init__()
        # (1) Preprocessor
        self.preprocessor = preprocessor
        
        # (2) Word Attention
        self.word_attention = registry.instantiate(
            callable=registry.lookup("word_attention", word_attention["name"]),
            config=word_attention,
            unused_keys=("name",),
            device=device,
            preprocessor=preprocessor.preprocessor)
        
        # (3) Sentence Attention
        self.sentence_attention = registry.instantiate(
            callable=registry.lookup("sentence_attention", sentence_attention["name"]),
            config=sentence_attention,
            unused_keys=("name",),
            device=device)
        
        # (4) Feed-forward classifier head (FFNN)
        self.mlp = nn.Sequential(
            torch.nn.Linear(self.sentence_attention.recurrent_size, final_layer_dim), 
            nn.ReLU(), 
            nn.Dropout(final_layer_dropout),
            torch.nn.Linear(final_layer_dim, self.preprocessor.get_num_classes())
        )
	    
        # (5) Cross Entropy Loss
        self.loss = nn.CrossEntropyLoss(reduction="mean").to(device)
        
    ##########################################################################################
    
    def forward(self, docs, doc_lengths, sent_lengths, 
                labels=None, attention_masks=None, token_type_ids=None):
        #----------------------------------------------------------------------#
        # [STEP 1] Build sentence embeddings with word-level attention
        if attention_masks is not None and token_type_ids is not None:
            sent_embeddings, doc_perm_idx, doc_bs, word_att_weights = self.word_attention(
                docs, doc_lengths, sent_lengths, attention_masks, token_type_ids)
        else:
            sent_embeddings, doc_perm_idx, doc_bs, word_att_weights = self.word_attention(
                docs, doc_lengths, sent_lengths)
            
        #----------------------------------------------------------------------#
        # [STEP 2] Build document embeddings with sentence-level attention
        doc_embeds, word_att_weights, sentence_att_weights = self.sentence_attention(
            sent_embeddings, doc_perm_idx, doc_bs, word_att_weights
        )
        
        #----------------------------------------------------------------------#
        # [STEP 3] Final outputs
        ### 1) scores (document classification logits)
        ### 2) attention weights (word & sentence)
        scores = self.mlp(doc_embeds)
        outputs = (scores, word_att_weights, sentence_att_weights,)
        if labels is not None:
            loss = self.loss(scores, labels)
            return outputs + (loss,)
        return outputs
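
For reference, a minimal self-contained sketch of just the classifier head from (4) and (5), with made-up sizes (sentence recurrent size 128, final_layer_dim 64, dropout 0.3, 5 classes, a batch of 100 document vectors):

import torch
import torch.nn as nn

recurrent_size, final_layer_dim, num_classes = 128, 64, 5   # hypothetical sizes
mlp = nn.Sequential(
    nn.Linear(recurrent_size, final_layer_dim),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(final_layer_dim, num_classes),
)
loss_fn = nn.CrossEntropyLoss(reduction="mean")

doc_embeds = torch.randn(100, recurrent_size)               # stand-in for the SentenceAttention output
labels = torch.randint(0, num_classes, (100,))
scores = mlp(doc_embeds)                                    # (100, num_classes)
loss = loss_fn(scores, labels)                              # scalar training loss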