Post

TART reranker 원리와 코드 알아보기

메타에서 나온 Information Retrieval 기법. Task-aware Retrieval with Instructions이라는 논문에 잘 소개가 되어 있고, 깃허브도 있다.

핵심

TART의 핵심은 “retrieval에 검색할 때 Instruction을 포함시킨다” 에 있다. 단순한 유저의 질문이 아니라, 질문 문장에는 담기지 않은 의도까지 포함하여 검색할 수 있는 것이다. 예시를 들어보겠다.

예시

유저가 다음과 같은 질문을 한다고 하자.

질문 : Python에서 1부터 10까지 더하는 방법을 알려줘.

이러한 경우 유저에게는 두 가지 의도가 있을 수 있다.

  1. Python에서 1부터 10까지 더하는 코드를 작성해줘.
  2. Python에서 1부터 10까지 더하는 방법을 상세하게 줄글로 설명해줘.

즉, 이 경우 retrieve 해야하는 문서의 내용이 코드와 코드에 대한 풀이글, 두 가지로 나눠지는 것이다. 통상적인 경우에는 이런 경우 코드에 대한 retriever, 풀이글에 대한 retriever을 따로 사용했다.

하지만 TART는 유저의 의도를 함께 retrieve 할 때 넣어주고, 그러면 TART가 의도에 알맞게 문서를 찾아준다. 위의 예시에서 유저가 코드를 원했다면 코드 문서를, 풀이글을 원했다면 풀이글 문서를 retrieve 해 주는 것이다.

만든 방법

연구진들은 BERRI라는 instruction-query-정답 문서 쌍이 있는 데이터셋을 만들고, 그것으로 retriever를 학습시켰다. 더 자세한 부분은 논문을 참고해주시길.

TART-dual vs TART-full

TART는 두가지 모델이 있다. TART-dual은 보통의 임베딩 모델과 비슷하다. 유저의 쿼리와 각 문서들간의 유사도만을 비교한다. TART-full은 유저 쿼리와 각 문서들간의 관계 뿐 아니라, 문서들 간의 관계 역시 고려한다. instruction에 대해 어떤 문서가 더 가까운지 서로 비교해가며 연산하기 때문에 더욱 정확하다. 하지만 당연히 계산량이 훨씬 크겠지?

코드

RAGchain에서는 TART-full을 리랭커 형태로 준비했다. TART-dual의 모델은 허깅페이스에 올라와 있지 않고, TART-full의 연산량이 크기 때문이다. 실제로 논문에서도 TRAT-full은 다른 방법으로 retrieve 한 문서들에 사용했다고 하니, 딱 리랭커인 것이다. TART 리랭커가 추가된 PR은 여기서 확인할 수 있다.

간단한 TART 코드 예시는 아래와 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import List
import torch
import torch.nn.functional as F
from .modeling_enc_t5 import EncT5ForSequenceClassification
from .tokenization_enc_t5 import EncT5Tokenizer

def tart_full(query: str, contents: List[str], instruction: str) -> List[str]:
    model_name = "facebook/tart-full-flan-t5-xl"
    model = EncT5ForSequenceClassification.from_pretrained(model_name)
    tokenizer = EncT5Tokenizer.from_pretrained(model_name)
    instruction_queries: List[str] = ['{0} [SEP] {1}'.format(instruction, query) for _ in range(len(contents))]

    features = tokenizer(instruction_queries, contents, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        scores = model(**features).logits
        normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

    sorted_pairs = sorted(zip(contents, normalized_scores), key=lambda x: x[1], reverse=True)
    sorted_contents = [content for content, score in sorted_pairs]
    return sorted_contents

modeling_ent_t5와 tokenization_enc_t5는 원저자의 깃허브에도 있고 RAGchain에도 있다. TART 참 쉽죠?

This post is licensed under CC BY 4.0 by the author.