TART reranker 원리와 코드 알아보기
메타에서 나온 Information Retrieval 기법. Task-aware Retrieval with Instructions이라는 논문에 잘 소개가 되어 있고, 깃허브도 있다.
핵심
TART의 핵심은 “retrieval에 검색할 때 Instruction을 포함시킨다” 에 있다. 단순한 유저의 질문이 아니라, 질문 문장에는 담기지 않은 의도까지 포함하여 검색할 수 있는 것이다. 예시를 들어보겠다.
예시
유저가 다음과 같은 질문을 한다고 하자.
질문 : Python에서 1부터 10까지 더하는 방법을 알려줘.
이러한 경우 유저에게는 두 가지 의도가 있을 수 있다.
- Python에서 1부터 10까지 더하는 코드를 작성해줘.
- Python에서 1부터 10까지 더하는 방법을 상세하게 줄글로 설명해줘.
즉, 이 경우 retrieve 해야하는 문서의 내용이 코드와 코드에 대한 풀이글, 두 가지로 나눠지는 것이다. 통상적인 경우에는 이런 경우 코드에 대한 retriever, 풀이글에 대한 retriever을 따로 사용했다.
하지만 TART는 유저의 의도를 함께 retrieve 할 때 넣어주고, 그러면 TART가 의도에 알맞게 문서를 찾아준다. 위의 예시에서 유저가 코드를 원했다면 코드 문서를, 풀이글을 원했다면 풀이글 문서를 retrieve 해 주는 것이다.
만든 방법
연구진들은 BERRI라는 instruction-query-정답 문서 쌍이 있는 데이터셋을 만들고, 그것으로 retriever를 학습시켰다. 더 자세한 부분은 논문을 참고해주시길.
TART-dual vs TART-full
TART는 두가지 모델이 있다. TART-dual은 보통의 임베딩 모델과 비슷하다. 유저의 쿼리와 각 문서들간의 유사도만을 비교한다. TART-full은 유저 쿼리와 각 문서들간의 관계 뿐 아니라, 문서들 간의 관계 역시 고려한다. instruction에 대해 어떤 문서가 더 가까운지 서로 비교해가며 연산하기 때문에 더욱 정확하다. 하지만 당연히 계산량이 훨씬 크겠지?
코드
RAGchain에서는 TART-full을 리랭커 형태로 준비했다. TART-dual의 모델은 허깅페이스에 올라와 있지 않고, TART-full의 연산량이 크기 때문이다. 실제로 논문에서도 TRAT-full은 다른 방법으로 retrieve 한 문서들에 사용했다고 하니, 딱 리랭커인 것이다. TART 리랭커가 추가된 PR은 여기서 확인할 수 있다.
간단한 TART 코드 예시는 아래와 같다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import List
import torch
import torch.nn.functional as F
from .modeling_enc_t5 import EncT5ForSequenceClassification
from .tokenization_enc_t5 import EncT5Tokenizer
def tart_full(query: str, contents: List[str], instruction: str) -> List[str]:
model_name = "facebook/tart-full-flan-t5-xl"
model = EncT5ForSequenceClassification.from_pretrained(model_name)
tokenizer = EncT5Tokenizer.from_pretrained(model_name)
instruction_queries: List[str] = ['{0} [SEP] {1}'.format(instruction, query) for _ in range(len(contents))]
features = tokenizer(instruction_queries, contents, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
scores = model(**features).logits
normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
sorted_pairs = sorted(zip(contents, normalized_scores), key=lambda x: x[1], reverse=True)
sorted_contents = [content for content, score in sorted_pairs]
return sorted_contents
modeling_ent_t5와 tokenization_enc_t5는 원저자의 깃허브에도 있고 RAGchain에도 있다. TART 참 쉽죠?