Semantic Search with LLMs
쉽고 빠르게 익히는 실전 LLM (https://product.kyobobook.co.kr/detail/S000212147276)
Lab 2. Re-ranking search results with a cross-encoder
(Reference: https://github.com/sinanuozdemir/quick-start-guide-to-llms/blob/main/notebooks/02_semantic_search.ipynb)
Procedures
- Step 1) Import the required packages
- Step 2) Re-ranking function
- Step 3) Load the re-ranking model
- Step 4) Get the questions from the test dataset
- Step 5) Return the information/answer for a specific query
Step 1) Import the required packages
from sentence_transformers.cross_encoder import CrossEncoder
import numpy as np
from torch import nn
from copy import copy
Step 2) Re-ranking function
# `query_from_pinecone` is defined earlier in the notebook (not shown here)
def get_results_from_pinecone(query, top_k=3, re_rank_model=None, verbose=True, correct_hash=None):
    # (1) Retrieve the top-K most relevant documents
    results_from_pinecone = query_from_pinecone(query, top_k=top_k)
    if not results_from_pinecone:
        return []
    if verbose:
        print("Query:", query)
    final_results = []
    retrieved_correct_position, reranked_correct_position = None, None
    # Position of the correct document in the original retrieval order
    for idx, result_from_pinecone in enumerate(results_from_pinecone):
        if correct_hash and result_from_pinecone['id'] == correct_hash:
            retrieved_correct_position = idx
    # (2) Re-rank with the re-ranking model
    if re_rank_model is not None:
        if verbose:
            print('Document ID (Hash)\t\tRetrieval Score\tCE Score\tText')
        # (2-1) Build (query, document text) pairs
        sentence_combinations = [[query, result_from_pinecone['metadata']['text']] for result_from_pinecone in results_from_pinecone]
        # (2-2) Score each (query, document text) pair
        similarity_scores = re_rank_model.predict(sentence_combinations, activation_fct=nn.Sigmoid())
        # (2-3) Sort by score, highest first
        sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
        sim_scores_sort = list(reversed(np.sort(similarity_scores)))
        top_re_rank_score = sim_scores_sort[0]
        # Collect (and optionally print) the results in re-ranked order
        for rank, idx in enumerate(sim_scores_argsort):
            result_from_pinecone = results_from_pinecone[idx]
            if correct_hash and result_from_pinecone['id'] == correct_hash:
                # Record the rank after re-ranking, not the original retrieval index
                reranked_correct_position = rank
            final_results.append({'score': similarity_scores[idx], 'id': result_from_pinecone['id'], 'metadata': result_from_pinecone['metadata']})
            if verbose:
                print(f"{result_from_pinecone['id']}\t{result_from_pinecone['score']:.2f}\t{similarity_scores[idx]:.6f}\t{result_from_pinecone['metadata']['text'][:50]}")
        return {'final_results': final_results, 'retrieved_correct_position': retrieved_correct_position, 'reranked_correct_position': reranked_correct_position, 'results_from_pinecone': results_from_pinecone, 'top_re_rank_score': top_re_rank_score}
    # No re-ranking model: return the results in retrieval order
    if verbose:
        print('Document ID (Hash)\t\tRetrieval Score\tText')
    for result_from_pinecone in results_from_pinecone:
        final_results.append(result_from_pinecone)
        if verbose:
            print(f"{result_from_pinecone['id']}\t{result_from_pinecone['score']:.2f}\t{result_from_pinecone['metadata']['text'][:50]}")
    return {'final_results': final_results, 'retrieved_correct_position': retrieved_correct_position, 'reranked_correct_position': reranked_correct_position}
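The two positions this function tracks are easy to mix up, so here is a toy illustration in plain Python. The retrieved position is the index in the original retrieval order; the re-ranked position is the rank after sorting by the cross-encoder score. `rerank_positions`, the document IDs, and the scores are all made up for this sketch and are not part of the notebook.

```python
def rerank_positions(doc_ids, rerank_scores, correct_id):
    """Return (retrieved_position, reranked_position); None if correct_id is absent."""
    # Position in the original retrieval order (best first).
    retrieved = doc_ids.index(correct_id) if correct_id in doc_ids else None
    # Candidate indices sorted by re-rank score, highest first.
    order = sorted(range(len(doc_ids)), key=lambda i: rerank_scores[i], reverse=True)
    reranked = None
    for rank, idx in enumerate(order):
        if doc_ids[idx] == correct_id:
            reranked = rank  # rank in the re-sorted list, not the original index
    return retrieved, reranked

doc_ids = ["doc_a", "doc_b", "doc_c"]   # hypothetical retrieval order
rerank_scores = [0.20, 0.35, 0.90]      # hypothetical cross-encoder scores
print(rerank_positions(doc_ids, rerank_scores, "doc_c"))  # (2, 0)
```

Here the correct document was retrieved third but moves to first place after re-ranking, which is the case the re-ranker is meant to help with.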
Step 3) Load the re-ranking model
Load a pretrained CrossEncoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', num_labels=1)
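With `num_labels=1` the model produces a single score per (query, text) pair, and the Step 2 function passes `activation_fct=nn.Sigmoid()` to map that score into the range (0, 1). The sketch below shows that mapping in plain Python, without loading the model; the raw scores are made-up values, not actual model outputs.

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function mapping any real score into (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

raw_scores = [-4.0, 0.0, 6.5]            # hypothetical raw cross-encoder scores
probs = [sigmoid(s) for s in raw_scores]
print([round(p, 4) for p in probs])

# Sigmoid is monotonic, so it rescales the scores without changing their order.
order_raw = sorted(range(len(raw_scores)), key=lambda i: raw_scores[i], reverse=True)
order_prob = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
print(order_raw == order_prob)  # True
```

Because the order is preserved, the activation affects how the scores read (values close to 1 for strong matches), not which document ends up on top.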
Step 4) Get the questions (queries) from the test dataset
unique_inputs = list(set(dataset['test']['question']))
print(len(unique_inputs))
1148
Check an example question (query)
query = unique_inputs[0]
print(query)
How close did John make it to Johnston Atoll?
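One caveat: `list(set(...))` does not guarantee a stable order for strings across Python processes, so `unique_inputs[0]` may be a different question on another run. If a reproducible order matters, one option (not what the notebook does) is to deduplicate while keeping first-seen order. The question list below is a made-up stand-in for `dataset['test']['question']`.

```python
# Hypothetical stand-in for dataset['test']['question'], with duplicates.
questions = ["q2", "q1", "q2", "q3", "q1"]

# dict keys preserve insertion order, so this drops duplicates
# while keeping the first occurrence of each question in place.
unique_in_order = list(dict.fromkeys(questions))
print(unique_in_order)  # ['q2', 'q1', 'q3']
```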
Check the context for that question (query)
for t in dataset['test']:
    if t['question'] == query:
        print(t['context'])
The origins of Hurricane John were thought by the United States National Hurricane Center (NHC) to be from a tropical wave that moved off the coast of Africa on July 25, 1994. The wave subsequently moved across the Atlantic Ocean and Caribbean without distinction, before it crossed Central America and moved into the Eastern Pacific Ocean on or around August 8.....
Dictionary mapping each question (query) to the hash of its context
q_to_hash = {data['question']: my_hash(data['context']) for data in dataset['test']}
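`my_hash` is defined earlier in the notebook and is not shown in this section. The 32-character hexadecimal IDs in the outputs below are consistent with an MD5 digest, so the following is a minimal sketch under that assumption; the real definition may differ. The sample rows are made up.

```python
import hashlib

def my_hash(text: str) -> str:
    # Assumption: the document ID is the MD5 hex digest of the UTF-8 text.
    return hashlib.md5(text.encode("utf-8")).hexdigest()

# Hypothetical rows standing in for dataset['test'].
rows = [
    {"question": "q1", "context": "passage one"},
    {"question": "q2", "context": "passage two"},
]
q_to_hash = {row["question"]: my_hash(row["context"]) for row in rows}
print(len(q_to_hash["q1"]))  # 32
```

Since the same context always hashes to the same ID, `q_to_hash[query]` can be compared directly against the `id` field of each retrieved document.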
Step 5) Return the information/answer for a specific query
5-1) Check the top 2 results
query_result = get_results_from_pinecone(
    query,
    top_k=2,
    re_rank_model=cross_encoder,
    correct_hash=q_to_hash[query],
    verbose=False
)
Result: no matching document found (the correct context is not among the top 2 results)
query_result['retrieved_correct_position'], query_result['reranked_correct_position']
(None, None)
query_result
{'final_results':
[{'score': 0.99843186,
'id': 'a76b6a3dfcbdb1ca832bbf40710ad2c8',
'metadata': {'date_uploaded': '2024-02-04T15:46:32.060499',
'text': "John affected b...
..
{'score': 0.991396,
'id': '8f3fd30f7d46c05089f7f84d71806b77',
'metadata': {'date_uploaded': '2024-02-04T15:46:41.323346',
'text': "Clearing Johnston Atoll
..
'retrieved_correct_position': None,
'reranked_correct_position': None,
..
5-2) Check the top 10 results
query_result = get_results_from_pinecone(
    query,
    top_k=10,
    re_rank_model=cross_encoder,
    correct_hash=q_to_hash[query],
    verbose=False
)
query_result['retrieved_correct_position'], query_result['reranked_correct_position']
(2, 2)
→ Of the 10 results, the third one (position=2, zero-indexed) contains the relevant answer.
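The difference between 5-1 and 5-2 comes from the candidate cutoff: a re-ranker can only reorder what retrieval returned, so a correct document outside the top K stays unreachable. A toy sketch of that effect follows; `position_in_top_k` and the document IDs are hypothetical.

```python
def position_in_top_k(ranked_ids, correct_id, top_k):
    """Return the position of correct_id within the first top_k results, else None."""
    candidates = ranked_ids[:top_k]
    return candidates.index(correct_id) if correct_id in candidates else None

# Hypothetical retrieval order; the correct document sits at position 2.
ranked_ids = ["d0", "d1", "target", "d3", "d4"]
print(position_in_top_k(ranked_ids, "target", top_k=2))   # None
print(position_in_top_k(ranked_ids, "target", top_k=10))  # 2
```

A larger `top_k` gives the re-ranker more chances to find the correct document, at the cost of scoring more pairs per query.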
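Lessons
- Simply take the results in order of retrieval similarity? No! The top result may not be the answer you want.
- Use a model to re-rank the (query, answer) pairs and reorder the results!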
Summary
Full code
import logging
from tqdm import tqdm
test_sample = dataset['test']
TOP_K = 50
# `logger` is assumed to have been created earlier in the notebook
logger.setLevel(logging.CRITICAL)
predictions = []
for question in tqdm(test_sample['question']):
    r = get_results_from_pinecone(
        question, top_k=TOP_K, re_rank_model=cross_encoder, correct_hash=q_to_hash[question],
        verbose=False
    )
    predictions.append(r)
    # Report the running top-1 accuracy every 100 questions
    if len(predictions) % 100 == 0:
        retrieved_accuracy = sum([_['retrieved_correct_position'] == 0 for _ in predictions])/len(predictions)
        re_ranked_accuracy = sum([_['reranked_correct_position'] == 0 for _ in predictions])/len(predictions)
        print(f'Accuracy without re-ranking: {retrieved_accuracy}')
        print(f'Accuracy with re-ranking: {re_ranked_accuracy}')
9%|▊ | 100/1148 [02:39<27:57, 1.60s/it]
Accuracy without re-ranking: 0.78
Accuracy with re-ranking: 0.84
17%|█▋ | 200/1148 [05:19<24:58, 1.58s/it]
Accuracy without re-ranking: 0.765
Accuracy with re-ranking: 0.835
26%|██▌ | 300/1148 [07:59<22:08, 1.57s/it]
Accuracy without re-ranking: 0.7666666666666667
Accuracy with re-ranking: 0.8166666666666667
35%|███▍ | 400/1148 [12:10<19:30, 1.56s/it]
Accuracy without re-ranking: 0.7625
Accuracy with re-ranking: 0.825
44%|████▎ | 500/1148 [14:50<17:13, 1.59s/it]
Accuracy without re-ranking: 0.764
Accuracy with re-ranking: 0.834
52%|█████▏ | 600/1148 [17:29<14:38, 1.60s/it]
Accuracy without re-ranking: 0.7683333333333333
Accuracy with re-ranking: 0.8466666666666667
61%|██████ | 700/1148 [20:08<12:14, 1.64s/it]
Accuracy without re-ranking: 0.7471428571428571
Accuracy with re-ranking: 0.8285714285714286
70%|██████▉ | 800/1148 [22:46<09:11, 1.58s/it]
Accuracy without re-ranking: 0.74875
Accuracy with re-ranking: 0.82125
78%|███████▊ | 900/1148 [25:26<06:36, 1.60s/it]
Accuracy without re-ranking: 0.7388888888888889
Accuracy with re-ranking: 0.8188888888888889
87%|████████▋ | 1000/1148 [28:05<04:05, 1.66s/it]
Accuracy without re-ranking: 0.741
Accuracy with re-ranking: 0.819
96%|█████████▌| 1100/1148 [30:45<01:16, 1.60s/it]
Accuracy without re-ranking: 0.7509090909090909
Accuracy with re-ranking: 0.8272727272727273
100%|██████████| 1148/1148 [32:01<00:00, 1.67s/it]
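The accuracy printed above is top-1 accuracy: the share of questions whose correct context sits at position 0. A position of `None` (correct context not retrieved at all) counts as a miss, because `None == 0` is `False`. The sketch below restates that computation on made-up predictions; `top1_accuracy` is a hypothetical helper.

```python
def top1_accuracy(predictions, key):
    """Fraction of predictions whose correct document is at position 0."""
    if not predictions:
        return 0.0
    # None == 0 is False, so documents that were never retrieved count as misses.
    return sum(p[key] == 0 for p in predictions) / len(predictions)

# Hypothetical results for four questions.
predictions = [
    {"retrieved_correct_position": 0, "reranked_correct_position": 0},
    {"retrieved_correct_position": 2, "reranked_correct_position": 0},
    {"retrieved_correct_position": None, "reranked_correct_position": None},
    {"retrieved_correct_position": 0, "reranked_correct_position": 0},
]
print(top1_accuracy(predictions, "retrieved_correct_position"))  # 0.5
print(top1_accuracy(predictions, "reranked_correct_position"))   # 0.75
```

In this toy data the re-ranker fixes the second question but cannot help with the third, whose correct context was never retrieved.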