https://arxiv.org/abs/2007.04825
Fast Transformers with Clustered Attention
Transformers have been proven a successful model for a variety of tasks in sequence modeling. However, computing the attention matrix, which is their key component, has quadratic complexity with respect to the sequence length, thus making them prohibitivel
arxiv.org
Abstract.
Transformers는 sequence modeling의 다양한 작업에서 성공적인 모델임이 입증되었습니다. 하지만 transformer의 핵심 구성 요소인 attention matrix를 계산하는 것은 sequence 길이에 대해 quadratic complexity를 가지며, 이로 인해 긴 sequence에 대해서는 연산 비용이 엄청나게 높아집니다. 이 문제를 해결하기 위해 본 논문의 저자들은 clustered attention을 제안합니다. 이는 모든 queries들에 대해 개별적으로 attention을 계산하는 대신, queries들을 여러 군집(clusters)으로 그룹화하고 오직 각 군집의 centroids에 대해서만 attention을 계산합니다. 이러한 approximation을 더욱 개선하기 위해, 계산된 군집들을 사용하여 각 query당 가장 높은 attention을 받는 keys들을 식별하고 해당 key/query에 대해서만 정확한 내적을 추가로 계산합니다.
결과적으로, 군집의 개수가 고정되어 있을 때 sequence 길이에 대해 linear complexity를 가지는 모델이 됩니다. 본 논문에서는 두 개의 자동 음성 인식(ASR) 데이터셋을 이용해 평가했으며, 주어진 연산 예산 내에서 제안된 모델이 vanilla transformer의 성능을 일관되게 능가함을 보여줍니다. 마지막으로, pretrained BERT 모델을 GLUE 및 SQuAD 벤치마크에서 단 25개의 군집만으로 성능 저하 없이 근사해 냄으로써, 제안된 모델이 최소한의 군집 개수만으로도 임의의 복잡한 attention 분포를 성공적으로 근사할 수 있음을 입증합니다.
Introduction.
sequence modeling은 NN machine translation, image captioning, summarization, ASR 및 합성 등 다양한 application에 필수적인 머신러닝의 핵심 작업입니다. transformers는 앞서 언급한 대부분의 작업에서 SOTA 성능을 크게 발전시킨 강력한 도구임이 입증되었습니다. 특히, transformer는 self-attention을 도입하여 RNN에 내재된 vanishing-gradient 문제 없이 긴 sequence를 처리할 수 있게 해줍니다.
그럼에도 불구하고, self-attention의 사용은 sequence 길이에 대해 quadratic으로 확장되는 연산 및 메모리 요구 사항을 수반하며, 이는 긴 sequence에 대한 적용 가능성을 제한합니다. 입력 sequence를 queries와 keys로 나누고 각 queries가 모든 key에 attending하는 self-attention의 핵심 메커니즘을 고려하면 quadratic는 명백해집니다. 이를 해결하기 위해 최근에는 이러한 한계를 극복하는 방법 개발에 대한 관심이 높아지고 있습니다.
이러한 방법들은 크게 두 가지 뚜렷한 연구 방향으로 분류할 수 있습니다. 하나는 self-atteniton의 연산의 점근적 복잡도(asymptotic complexity)를 개선하는 데 초점을 맞추는 연구들이며, 다른 하나는 self-attention의 quadratic complexity를 직접 해결하지는 않으면서 transformer를 더 긴 sequence에 적용할 수 있도록 만드는 기법을 개발하는 데 목표를 둔 연구들입니다. 전자는 각 query가 집중하는 key의 양을 제안하여 복잡도를 줄이며, 후자는 self-attention 메커니즘의 근본적인 복잡도를 변겨앟지 않고 transformer가 처리할 수 있는 sequence의 길이를 늘립니다.
본 연구에서는 self-attentiond의 fast approximation인 clustered attention을 제안합니다. clustered attention은 queries들 사이의 유사성을 활용하여 이들을 그룹화함으로써 연산 비용을 줄입니다. 구체적으로, locality-sensitive hashing, LSH와 K-means를 사용하여 빠른 군집화를 수행하고 각 군집당 attention을 단 한 번만 계산합니다. 이는 군집의 개수가 고정되어 있을 때 linear complexity를 가져옵니다.
게다가, 우리는 각 군집에서 가장 높은 attention을 받는 keys들을 별도로 고려함으로써 근사의 품질을 더욱 향상시킬 수 있음을 보여줍니다. 마지막으로, full attention 대비 해당 논문의 근사 품질에 대한 이론적 한계(theoretical bounds)를 제공하며, 해당 모델이 pretrained transformer의 inference에 적용되어 성능 손실을 최소화할 수 있음을 보여줍니다.
Related Work.
Attention improvements Before Transformers.
attention은 수년 동안 sequence modeling을 위한 신경망의 필수적인 구성 요소였습니다. 그러나, sequence 길이에 대한 quadratic complexity는 긴 sequence에 대한 적용 가능성을 가로막습니다. 이를 해결하기 위한 첫 번재 시도 중 하나는 Britz et al. 의 연구로, 입력 sequence의 정보를 더 적은 수의 vector로 aggregate하고 이 적은 수의 vectors들로 attention을 수행하여 attention 연산 속도를 높이고 메모리 요구 사항을 줄일 것을 제안했습니다. 그러나 이러한 input aggregation은 학습되긴 하지만 모든 sequence에 대해 일정하게 유지되는 '고정된 행렬'을 사용하여 수행되므로 모델의 expressivity를 크게 제한합니다. 이와 유사하게, Chiu & Raffel et al. 은 과거에서 미래로 monotonically attention함으로써 attention이 접근할 수 있는 요소의 양을 제한합니다. 즉, timestep i가 위치 j에 attend하면, timestep i+1은 j 이전의 어떤 위치에도 attention할 수 없습니다. attention 연산 속도를 높이기 위해 앞서 언급한 방법들은 각 layer가 주의를 기울이는 요소의 수를 제한하고 있음에 유의해야 합니다.
Non-asymptotic Improvements.
self attention의 quadratic complexity를 직접적으로 개선하는 데 초점을 맞추지 않고 transformer를 긴 sequence에 적용하려는 기법들도 존재합니다. 이때 중요한 연구는 Adaptive Attention Span Transformer와 Transformer-XL입니다.
Sykhbaartar et al. 은 timestep에 대한 상대적 거리 측면에서 self-attention context를 가장 가까운 샘플들(attention span)로 제한하여 self-attention의 연산의 시간 및 메모리 요구 사항을 모두 줄일 것을 제안합니다. 이는 모델이 필요할 경우 attention span을 늘릴 수 있도록 학습 가능한 파라미터가 있는 masking 함수를 사용하여 달성됩니다(* 이는 앞서 언급한 입력 sequence의 정보를 고정된 행렬을 이용해 적은 수의 vector로 aggregation할 때 발생할 수 있는 표현력 저하를 방지하기 위함이다).
반면, Transformer-XL은 segment 수준의 recurrent training을 도입하여 유효 sequence 길이를 늘리려고 시도합니다. 즉, 입력을 segment로 나누고 이전 segment와 현재 segment에 공동으로 attention합니다. 이는 새로운 relative positional encoding과 결합되어, 학습 중 사용된 segment의 길이보다 더 멀리 떨어진 위치까지 attention할 수 있는 모델이 만들어집니다.
*Note: 원래 transformer는 메모리 한계 때문에 긴 책을 읽을 때 한 번에 512단어(1개 segment)씩만 잘라서 읽습니다. 1~512번 단어를 다 읽고 나면, 그 기억을 완전히 지워버리고(리셋) 513~1024번 단어를 새로 읽습니다. 이를 context fragmentation이라고 부르는데, 513번 단어를 해석할 때 바로 앞인 512번 단어를 참고할 수 없는 치명적인 문제가 생기는 것입니다. 따라서 Transformer-XL은 이전 segment를 버리지 말고, cache 메모리에 저장합니다. 그래서 현재 페이지(513~1024)를 읽을 때, 메모리에 올려둔 이전 페이지(1~512)의 정보(Key, Value)를 곁눈질로 함께 참고합니다. 이렇게 하면 문맥이 끊기지 않고 물 흐르듯 이어집니다.
*Note: 하지만 이전 segment와 함께 참고하면 위치 번호에 충돌이 생깁니다. 기존 transformer는 단어마다 절대적인 번호표(absolute positional encoding)를 붙입니다. 만약 이전 segment와 현재 segment를 같이 놓으면, 이전 segment의 1번 단어와 현재 segment의 1번 단어가 헷갈리게 됩니다. 그래서 절대적인 번호표를 떼버리고, 상대적 위치만 따지는 relative positional encoding을 사용하는 것입니다.
두 접근법 모두 효과적임이 입증되었지만, self-attention의 근본적인 한계는 여전히 남아있습니다. N timestep 떨어져 있는 요소에 attention하려면 여전히 O(N^2)의 메모리와 연산이 필요합니다. 대조적으로, 본 논문에서 제안한 모델은 full-attention에서의 약간의 오차를 감수하는 대신 linear asymptotic complexity로 개선하는 trade-off를 선택합니다. 이를 통해 긴 sequence의 처리가 가능해집니다.
Improvement in Asymptotic Complexity.
Child et al. (Sparse Transformer)은 self-attention 메커니즘을 local 및 strided attention으로 분해합니다. local attention은 가장 가까운 C개의 위치 사이에서 계산되고, strided attention은 서로 C단계 떨어져 있는 위치들 사이에서 계산됩니다. C가 √N으로 설정될 때 전체 복잡도는 메모리와 연산 시간 모두의 측면에서 O(N√N)이 됩니다. 앞서 언급한 분해를 사용하면 임의의 위치에 attend하기 위해 두 개의 self-attention layers가 필요합니다. 게다가 이 분해는 고정되어 있으며 데이터와 무관(data independent)합니다. 이는 특정 신호(예: 이미지)에 대해서는 직관적일 수 있지만 대부분의 경우 임의적(arbitary)입니다. 대조적으로, 제안된 모델은 수동으로 설계된 분해 매커니즘 없이도 유사한 입력 queries들을 자동으로 그룹화합니다. 더욱이 제안된 모델은 정보가 항상 모든 위치에서 다른 모든 위치로 흐릅니다(sparse attention은 고정된 위치에 대해서만 정보가 흐른다).
*Note: Sparse Transformer는 NxN 번을 전부 계산하는 것이 너무 무거우니, 연산을 두 가지 패턴으로 쪼개서 띄엄띄엄 계산하자고 제안합니다. 데이터 길이 N = 100이고, stride C = 10으로 설정했다면:
- Local Attention: 내 주변(가장 가까운 C개)만 집중해서 봅니다. (나를 기준으로 앞뒤 10개 단어만 확인)
- Strided Attention: 넓은 시야를 확보하기 위해 징검다리처럼 듬성듬성(C단계 떨어져서) 봅니다. (10번째, 20번째, 30번째 단어만 확인)
이렇게 하면 원래 100개를 다 봐야 했던 단어가, 로컬 10개 + 스트라이드 10개 = 총 20개(2√N)만 봐도 됩니다. 하지만 무조건 10칸마다 한 번씩 쳐다보게 고정(Fixed)한다는 문제가 있습니다. 즉, data independent하게 패턴이 결정된다는 문제입니다.
Set Transformers는 길이 N의 입력 sequence X와 inducing points라고 불리는 학습 가능한 파라미터 집합 I 사이의 attention을 계산하여 길이 M (M<<N)의 새로운 sequence H를 얻습니다. 그런 다음 이 새로운 sequence H는 X와 attention을 계산하여 출력 표현을 얻는 데 사용됩니다. 고정된 M에 대해, 점근적 복잡도는 sequence 길이에 대해 linear가 됩니다. inducing points는 task-specific한 어떤 전역적인 구조를 encoding할 것으로 기대됩니다. 그러나 이는 각 attention layer마다 추가적인 모델 파라미터를 도입합니다. 이와 대조적으로 본 논문의 방법이 파라미터 수의 증가 없이 입력을 길이가 더 짧은 고정된 sequence로 투영하기 위해 clustering을 사용합니다. 더욱이, 본 방법은 동일한 점근적 복잡도를 가질 뿐만 아니라, 추가 학습 없이 사전 학습된 모델의 추론 속도를 높이는 데에도 사용될 수 있음을 보여줍니다.
Scaling Attention with Fast Clustering.
Vanilla Attention.
길이 N의 임의의 sequence에 대해, transformer에서 사용되는 표준 attention mechanism은 Vaswani et al. 이 도입한 dot product attention입니다. 표준 표기법을 따라, attention 행렬 A은 NxN을 다음과 같이 정의합니다:

여기서 Q는 NxDk는 queries, K는 NxDk는 keys를 나타냅니다. softmax()는 row 단위로 적용됩니다. attention 가중치 A와 valeus V(NxDv)를 사용하여, 다음과 같이 새로운 값 V_hat을 계산합니다:

위에서 설명한 attention에 대한 직관적인 이해는, Q, K, V가 주어졌을 때 이전 값들의 가중 평균(weighted average)으로 새로운 값 V_hat을 생성하며, 여기서 가중치는 attention matrix A에 의해 정의됩니다. 위 softmax 식을 계산하는 데에는 O(N^2Dk)의 연산이 필요하며, AV를 통해 attention output을 계산하는 데에는 O(N^2Dv)의 연산이 필요하므로 전체 점근적 복잡도는 O(N^2Dk + N^2Dv)가 됩니다.
Clustered Attention.

모든 queries들에 대해 attention matrix를 계산하는 대신, queries들을 C개의 clusters으로 그룹화하고 오직 이 군집들에 대해서만 attention을 계산합니다. 그런 다음, 같은 군집에 속한 queires들에 대해서는 동일한 attention 가중치를 사용합니다. 결과적으로 attention 연산은 이제 O(NCDk)가 되며, 여기서 C << N입니다.
보다 공식적으로 쓰면, queries Q를 겹치지 않는 C개의 군집으로 나누는 partitioning 행렬 S ∈ {0,1} NxC 를 정의합니다. 즉, i번째 queries가 j번째 군집에 속하면 Sij = 1이고, 그렇지 않으면 0입니다. 이 분할을 사용하여 이제 clustered attention을 계산할 수 있습니다. 먼저 다음과 같이 군집의 centroids를 계산합니다:

여기서 Q^c_j는 j번째 군집의 centroid입니다. centroid matrix를 Q^c는 CxDk라고 하면, 이제 Q^c가 queries인 것처럼 clustered attention을 계산할 수 있습니다. 즉, clustered attention matrix A^c CxN을 계산합니다:

그리고 새로운 값 V_hat^c를 계산합니다:

마지막으로, i번째 query의 값은 그것과 가장 가까운 중심점의 값이 됩니다. 즉:

위의 분석에서 알 수 있듯이, attention 가중치와 values들의 가중 평균을 각 군집당 한 번씩만 계산하면 됩니다. 그런 다음 동일한 군집에 속한 모든 queries에 동일한 값을 broadcast하여 나눠줄 수 있습니다. 이를 통해 dot products의 횟수를 각 queries당 N번에서 각 군집당 C번으로 줄일 수 있으며, 그 결과 O(NCDk) + O(CNDv)의 점근적 복잡도를 얻게 됩니다.
이때 실제로는 multi-head attention을 사용하므로, 동일한 군집에 속했던 두 queries가 다른 attention head에서는 다르게 군집화될 수 있습니다. 게다가 attention layer의 출력은 residual connection을 포함합니다. 이로 인해 동일한 군집에 속했던 두 queries라도 최종 출력 표현은 달라질 수 있습니다. residual connection과 multi-head attention의 결합된 효과는 후속 layer에서 완전히 새로운 군집화 패턴이 나타날 수 있게 해줍니다.
*Note: 같은 군집에 속한 queries들의 feature value를 centroid에 대한 attention value로 broadcast하는 것은 feature의 표현력이나 다양성의 측면에서 문제가 될 수 있음을 residual connection과 multi-head attention을 통해 극복합니다.
위 내용을부터, queries들을 군집으로 그룹화하는 것이 self-attention의 연산 속도를 높일 수 있음을 보여주었습니다. 하지만 이전 분석에서는 군집화가 attention 가중치 A에 미치는 영향을 고려하지 않았습니다. 이를 해결하기 위해, 본 논문에서는 근사 오차에 대한 bound를 도출합니다. 구체적으로, attention의 차이가 queries들 사이의 유클리드 거리에 대한 함수로 bound될 수 있음을 보여줍니다:

||Qi - Qj|| ≤ ε 을 만족하는 두 queries Qi와 Qj가 주어졌을 때, K의 spectral norm ||K||의 ε으로 attention 오차가 bounded 됩니다. softmax() 함수가 1 미만의 Lipschitz constant를 가진다는 점을 고려한다면, proposition 1. 은 유클리드 공간에서 가까운 queries들은 유사한 attention 분포를 가진다는 것을 보여줍니다. 결과적으로, j번째 군집에 할당된 i번째 query에 대한 attention 근사 오차는 그 query와 cluster centroid Q^c_j 사이의 거리에 의해 bounded 될 수 있습니다.
지금까지의 논의를 통해, 대표적인 queries 집합이 주어지면 더 적은 연산으로 attention을 근사할 수 있음을 보여주었습니다. 따라서 이제 문제는 이 queries 집합을 어떻게 찾을 것인가가 됩니다. K-means clustering은 군집 멤버들 간의 거리 제곱합을 최소화하므로, 위 분석에 비추어 볼때 최적이 됩니다. 그러나 길이 N의 sequence에 대해 K-means 최적화 문제를 위한 Lloyd's algorithm을 1회 반복하는 엇은 O(NCDk)의 점근적 복잡도를 가집니다.
거리 계산 속도를 높이기 위해, queries들에 지역성 기반 해싱(locality-sensitive hashing, LSH)을 적용한 다음 hamming space에서 K-means를 수행합니다. 구체적으로, queries들을 hash하기 위해 무작위 투영(random projections)의 부호를 사용하고, 이어서 hamming distance를 metrix으로 사용하는 K-means clustering을 수행합니다. 이는 O(NCL + CBL CDkB)의 점근적 복잡도를 가져오며, 여기서 L은 Lloyd's algorithm 반복 횟수이고 B는 hashing에 사용된 bits 수입니다.
Improving Clustered Attention.

이전 섹션에서는 clustered attention이 softmax attention에 대한 빠른 근사를 제공함을 보여주었습니다. 이 섹션에서는 각 cluster에 대해 가장 높은 attetnion을 받는 keys들을 별도로 고려함으로써 이 근사를 어떻게 더 향상시킬 수 있는지 논의합니다. 이 과정의 중요성을 직관적으로 이해하기 위해서는, 특정 queries들에 대해 원래 attention이 낮아야 할 key가, centroid에 의한 근사 과정에서 억울하게 높은 attention을 받게 되는 시나리오를 고려해보면 충분합니다. 이러한 현상은 군집의 개수가 너무 적거나 K-means 알고리즘이 수렴에 실패했을 때 발생할 수 있습니다. 아래에서 논의할 변형 기법은 바로 이러한 한계를 해결합니다.
clustered attention A^c를 계산한 후, 각 군집에 대해 가장 높은 attention을 받는 top-k의 keys들을 찾습니다. 그런 다음 해당 군집에 속한 각 개별 queries들에 대해, 이 top-k keys들에 한해서만 attention 근사를 개선하는 것입니다. 이를 위해, 먼저 이 군집에 속한 모든 queries들과 top-k keys들 사이에 dot product attention을 직접 계산합니다. 그런데 임의의 queries에 대해 이 top-k keys들에 대해서만 계산된 attention의 합은 1이 될 것입니다. 이는 기존에 계산해 둔 clustered attention 값을 이 값으로 직접 대체할 수 없음을 의미합니다. 이 문제를 해결하기 위해, 대체를 수행하기 전에 새로 계산된 attention 값에 해당 군집화된 attention이 이 top-k keys들에게 할당했던 기존 확률 질량의 합을 곱하여 스케일을 조정합니다.
j번째 군집의 top-k key에 i번째 key가 포함되면 Tji = 1, 그렇지 않으면 0인 T matrix ∈ {0,1} CxN을 도입합니다. 그런 다음 j번째 군집의 top-k keys들에 대한 확률 질량 mj를 계산하면 다음과 같습니다:

이제 개선된 attention matrix 근사 A^t NxN 을 다음과 같이 나타낼 수 있습니다:

위 식에서 i는 j번째 군집에 속한 i번째 query를 나타내며, 명확성을 위해 √Dk는 생략되었습니다. 주어진 군집의 top-k key에 속하지 않는 keys들에 대해서는 clustered attention 값을 그대로 선택하며, 나머지(top-k keys)들에 대해서는 query와 top-k keys들 사이의 직접적인 내적 attention 비율에 따라 확률 질량 mj를 재분배합니다. 이에 대응하는 새로운 값 V_hat은 A^t와 values 행렬의 곱입니다:

위 식은 clustered attention 연산과 두 개의 sparse dot products(모든 queries와 top-k key 사이의 내적 1회 + top-k attention matrix와 해당 값들 사이의 내적 1회)로 분해될 수 있습니다.

이 명제에 대한 증명은 appendix에 있으며, improved clustered attention이 항상 clustered attention보다 full-attention을 더 잘 근사하는 것이 자명해집니다.
Experiments.
(... 원본 논문 참조 ...)