음...최근 최대한 셋을 돌고 코포랑 앳코더를 돌아보고 있는데 자꾸 망치니까 정신상태가 안좋아지는 것 같아서 당분간은 셋을 도는 것에 주력하지 않고 웰노운 아닌 웰노운 알고리즘들을 배워보고자 한다. 뭔가 이론을 배우는 것 자체는 힐링되는 느낌이라 좋은 것 같다. 이제부터 배우는 족족 최대한 빠르게 "웰노운" 알고리즘들을 정리해서 올려볼 것이다. 만약 내가 블로그를 올리지 않는다면 독촉바란다. 아무튼, 이번에 다룰 대상은 일반적으로 두 다항식을 빠르게 곱하는 것으로 잘 알려져있는 FFT와 그 변형인 NTT, 온라인 FFT, FWHT에 대해 다룰 것이다.
1. FFT
해당 절에서는 이 문제를 해결하며, 이해하는 것을 목표로 두어도 괜찮을 것이다: 15576번: 큰 수 곱셈 (2) (acmicpc.net)
일단 FFT는 가장 기본적인 내용이라고 볼 수 있는데, 이를 알기 이전에 특정한 컨셉에 대해서 알아볼 필요가 있다.
평면 위에 x좌표가 서로 다른 n개의 점이 있을 때, 이를 지나는 \(n-1\)차 다항식은 유일하게 존재한다.
만약 그런 다항식이 존재한다면 유일하다는 사실은, 일반적으로 어떤 객체가 유일한지를 증명하는 방법처럼 만약 두개가 존재한다고 가정한 다음, 그 차이에 대해서 분석하면 된다. 그렇다면, 존재성을 어떻게 판별하는가? 이는 라그랑주 보간법에 의해 성립하는데, 자세한 내용은 직접 찾아보면 될 것이다.
DFT라는 것은 이산 푸리에 변환으로, 어떤 \(n-1\)차 이하의 다항식이 주어졌을 때 그 다항식들에 n개의 수를 넣었을 때의 출력값을 계산하는 것이다. 만약 우리가 전부 다른 n개의 수를 집어넣었다면, 어떤 다항식의 이산 푸리에 변환만 봐도 그 다항식이 무엇이었는지 유추할 수 있다. 그 어떤 수를 집어넣어도 위의 조건이야 만족하겠지만, 일단 n th root of unity를 집어넣자. 즉, \(1=x^n\)의 n개의 서로 다른 해를 집어넣겠다는 것이다. 이 세팅이 나중에 얼마나 강력한 역할을 할지 기대해도 좋다.
만약 우리가 나이브하게 DFT를 계산하고자 한다면, 하나의 함숫값에 대해 \(O(n)\)이 소요되므로 총 시간복잡도는 \(O(n^2)\)정도가 될 것이다. 그러나 이는 우리의 목표를 달성하기에는 턱없이 느리다. 더 빠르게 할 수는 없을까?
분할정복을 활용하면 무려 \(O(nlogn)\)에 목표를 달성할 수 있다. 일단 우리가 함숫값을 구하고자 하는 다항식이 \(P(x)=a_0+a_1x+a_2x^2+...+a_{2n-1}x^{2n-1}\)이었다고 하자. 그리고 \(A(x)=a_0+a_2x+...,B(x)=a_1+a_3x+...\)와 같이 홀수차수와 짝수차수로 이루어진 다항식들을 생각하자. \(P(x)=A(x^2)+xB(x^2)\)임은 자명하다. 만약 우리가 A와 B에 대한 이산 푸리에 변환을 먼저 구했다고 생각하자. 그렇다면, P(w)를 구하기 위해선 \(A(w^2)+wB(w^2)\)를 계산하면 될 것이다. A와 B 테이블에 대해 생각해보면, 우리는 해당 방법으로 \(P(w^{n-1})\)까지를 계산 가능하다. 그런데, w가 root of unity라는걸 고려해보면 그 반댓편 값도 계산 가능하다. 즉, \(P(w^{x+n})=A(w^2)-wB(w^2)\)가 된다. 따라서 P의 모든 테이블을 채울 수 있게 되고, 분할정복적 관점에서 보면 전체 다항식에 대한 답을 구하기 위해선 \(O(nlogn)\)의 시간만 있으면 충분하다.
이 fft는 물론 재귀적으로 분할정복을 구현할 수도 있지만, 재귀로 짤 시 시간이 걸린다는 단점이 있다. 비재귀적으로 짜기 위해선 미리 각 인덱스들이 어느 위치에 있어서 분할정복이 잘 굴러갈 수 있는지를 감안해서 다항식의 순서를 잘 바꿔주고, 그 다음에 길이를 천천히 올려가며 계산하면 된다. 아래는 내가 현재 사용하고 있는 비재귀 fft이다.
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef double lf;
typedef complex<double> cd;
const lf PI=acos(-1);
void fft(vector<cd> &a,bool invert)
{
ll n=a.size();
ll i,j;
for(i=1,j=0;i<n;i++)
{
ll bit=n>>1;
for(;j&bit;bit>>=1)
j^=bit;
j^=bit;
if(i<j)
swap(a[i],a[j]);
}
for(ll len=2;len<=n;len<<=1)
{
lf ang=2*PI/len*(invert?-1:1);
cd wlen(cos(ang),sin(ang));
for(i=0;i<n;i+=len)
{
cd w(1);
for(j=0;j<len/2;j++)
{
cd u=a[i+j],v=a[i+j+len/2]*w;
a[i+j]=u+v;
a[i+j+len/2]=u-v;
w*=wlen;
}
}
}
if(invert)
for(cd &x:a)
x/=n;
}
만약 invert가 true라면, 역 dft를 하게 되는데, 이는 함숫값들을 통해 원래의 다항식을 찾아낸다는 뜻이다. 암튼 이렇게 짜면 된다.
2. NTT
이 주제에 대한 좋은 intuition은 14882번: 다항식과 쿼리 (acmicpc.net)이다.
일단 이 문제를 보면, 단순한 fft로는 무리가 있다는 사실을 알 수 있다. 왜냐하면 특정한 다항식에 함숫값을 대입했을 때의 문제를 해결해야 하기 때문이다. 해당 문제는 사실 NTT의 근본적 정의를 통해서 구할 수 있는데, NTT는 정수론적 푸리에 변환이다.
이 주제에 대해 배우기 전, 원시근이라는 개념에 대해 알아야 한다. 임의의 소수 p와 2p, 2,4에 대해서는 그 기약잉여계 내에서 적절한 정수 g가 존재하여, 모든 기약잉여계의 원소 x에 대해서 \(g^k=x\)인 k가 존재한다. 특히, \(g^{p-1}=1\)이다. 무언가 느낌이 오지 않는가? 잘 생각해보면 원시근이 root of unity와 동일한 작용을 한다는 사실을 알 수 있다. 사실 이에 대해 디리클레 지표등을 통해서 그 사이의 대응관계를 규명할 수 있는 방법이 있는 것으로 알지만, 굳이 여기에서 설명할 필요는 없다.
아무튼, 예를 들어 쓸만한 소수인 65537을 가져와보자. 아주 유명한 페르마 소수라는 사실을 알 수 있다. 여기에 있는 원시근 g를 하나 잡고 나면, fft에서 우리가 사용하였던 root of unity를 g의 특정한 거듭제곱으로 대체하여 사용할 수 있다. 즉, 우리는 크기가 65536인 다항식을 곱할 수 있다는 뜻이다.
다른 예시로, 다항식과 쿼리에서 나온 786433에 대해 분석해보자. 이 수는 2^18*3+1인데, 만약 262144정도 크기의 다항식에 대해 ntt를 돌리고 싶다면, 786433의 원시근인 10을 하나 잡고, 이를 세제곱한 1000을 기본단위로 ntt를 돌리면 된다. 왜냐하면, 1000은 법 786433에서 262144의 위수를 가지기 때문에, 262144 th root of unity와 동질적이기 때문이다. 흠...하지만 이걸로는 우리의 문제를 해결하기엔 부족하다.
분명 다항식과 쿼리에서는 법 786433 내의 모든 원소들에 대한 함숫값을 계산할 것을 요구하고 있다. 하지만 우리의 방식대로라면, 1000을 거듭제곱하여 표현할 수 있는, 전체의 1/3정도에 해당되는 값을에 대한 답밖에 알 수가 없다. 여기에서 특별한 방법을 사용하게 된다.
이전의 fft에서는 다항식을 항상 2개로 나누었지만, 3n-1차의 다항식 P에 대해서 이렇게도 할 수 있다:
$$P(x)=a_0+a_1x+a_2x^2+....,A(x)=a_0+a_3x+a_6x^2+...,B(x)=a_1+a_4+a_7x^2+...,C(x)=a_2+a_5x+a_8x^2+...$$
이제 답을 계산하고 싶다면, 다음과 같은 절차를 따르면 된다:
$$P(x)=A(x^3)+xB(x^3)+x^2C(x^3)$$
$$P(x*w^n)=A(x^3)+xw^nB(x^3)+x^2w^{2n}C(x^3)$$
$$P(x*w^2n)=A(x^3)+xw^{2n}B(x^3)+x^2w^nC(x^3)$$
즉, 3분할을 해도 별 상관이 없으며, 물론 이는 fft에 대해 적용해도 성립한다. 이 방법을 쓰면 786433의 모든 원소들에 대한 답을 계산이 가능할 것이다. 난 위 문제에 대한 코드를 다음과 같이 작성하였다.(디버깅용 주석은 무시하라)
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
const ll mod=786433;
const ll root=1000;
const ll root_1=710149;
void fft(vector<ll> &a,bool invert)
{
ll n=a.size();
ll i,j;
for(i=1,j=0;i<n;i++)
{
ll bit=n>>1;
for(;j & bit ; bit>>=1)
j^=bit;
j^=bit;
if(i<j)
swap(a[i],a[j]);
}
for(ll len=2;len<=n;len <<=1)
{
ll wlen=(invert?root_1:root);
for(i=len;i<262144;i<<=1)
wlen=(wlen*wlen)%mod;
for(i=0;i<n;i+=len)
{
ll w=1;
for(j=0;j<len/2;j++)
{
ll u=a[i+j],v=(a[i+j+len/2]*w)%mod;
a[i+j]=(u+v)%mod;
a[i+j+len/2]=(u-v+mod)%mod;
w=w*wlen%mod;
}
}
}
if(invert)
{
for(ll &x:a)
x=(x*(mod-3))%mod;
}
}
ll n,i,j,m,x,y,z,d[1100000],ans[1100000],t;
vector<ll> a,b,c;
ll g(ll x,ll y)
{
ll r=mod-2,s=1;
while(y>0)
{
if(y%2==1)
{
s=s*x%mod;
}
x=x*x%mod;
y/=2;
}
return s;
}
int main()
{
a.resize(262144);
b.resize(262144);
c.resize(262144);
scanf("%lld",&n);
for(i=0;i<=n;i++)
{scanf("%lld",&x);
if(i==0)
ans[0]=x;
if(i%3==0)
a[i/3]=x;
if(i%3==1)
b[i/3]=x;
if(i%3==2)
c[i/3]=x;
}
fft(a,false);
fft(b,false);
fft(c,false);
// printf("%lld %lld %lld\n",a[1],b[1],c[1]);
ll w=1;
for(i=0;i<262144;i++)
{
d[i]=(a[i]+b[i]*w+c[i]*w%mod*w)%mod;
d[i+262144]=(a[i]+b[i]*w*g(10,262144)+c[i]*w*w%mod*g(10,524288))%mod;
d[i+524288]=(a[i]+b[i]*w*g(10,524288)+c[i]*w*w%mod*g(10,262144))%mod;
w=w*10%mod;
}
//w=1000;
//printf("%lld\n",d[0]);
w=1;
for(i=0;i<786432;i++)
{
ans[w]=d[i];
w=w*10%mod;
}
scanf("%lld",&t);
for(i=1;i<=t;i++)
{
scanf("%lld",&x);
printf("%lld\n",ans[x]);
}
}
이제 임의의 분할을 활용하는 fft의 시간복잡도를 분석해보자. 일단 어떤 n이 주어지고 나면, 우리는 아마 n을 소인수분해하며 그 인수들을 기반으로 다항식을 분할하게 될 것이다. 만약 우리가 크기 p의 분할을 했다면, 해당 절차에서는 $$O(np)$$정도의 시간복잡도가 소요된다. 즉, 전체적으로 대충 다음과 같은 시간복잡도가 될 것이다: $$O(nlogn+n(n의 최대 소인수))$$
보통 998244353, 786433과 같이 그 크기에 비해 2의 거듭제곱이 많이 들어가는 소수들을 ftt-friendly prime이라고 하는 것 같은데, 사실은 소인수의 최댓값이 꽤 작은 경우에도 fft가 잘 돌아간다는 사실을 확인 가능하다.
3. online FFT
이제부터 많이 생소한 알고리즘이 된다. 이 online FFT는 이전까지의 모든 항을 전부 알아내야 다음 항을 알아낼 수 있는 경우, 즉 곱하는 두 다항식이 결과와 연관이 있거나 하는 등의 상황에서 사용될 수 있다. 가장 간단한 예시를 들자면, 카탈란 수를 들 수 있을 것이다. 카탈란 수는 다음 점화식을 만족시키기 때문이다:
$$C_n=C_0C_{n-1}+...+C_{n-1}C_0$$
즉, 두 다항식을 곱하고 싶은데 특정한 항을 모르고, 이전까지의 답을 계산해야만 답을 알 수 있는 경우에 사용될 수 있는 알고리즘이다. 연습문제로는 18743번: Bin (acmicpc.net)가 있다는 것 같은데, 아직 안풀어봤다. 곧 짤듯
아무튼, 이걸 어떻게 할 수 있을까? 대략적인 아이디어만을 소개하겠지만, 해당 아이디어를 적당히 변형하고 응용한다면 선형 점화식등의 여러 객체들에 대해서 적용이 가능하다.
핵심적인 아이디어는 분할정복이다. 구간 [L,R)에 대해 답을 구하고 싶다고 하자. 이때, 답이라 함은 R-1까지의 모든 항을 알아낸다는 의미이다. 그리고 중요한 가정은, [L,R) 구간에 대해 답을 계산하는 시점에서 이미 L-1까지는 전부 "답을 알고 있으며, 심지어는 곱하고자 하는 다항식들을 곱한 결과를 적어도 L-1까지는 알고 있어야" 한다는 점이다. 이제 분할 단계에서 [L,M)에 대한 답을 알아냈다고 하자. 우리는 현재 [L,M) 구간 상에 있는 계산하고자 하는 값들의 항에 대해 알고 있지만, 실제로 다항식을 곱한 결과에 대해 알아내지는 못했다. 이를 알아내기 위해선, 아주 나이브하게 다항식을 전부 곱하는 방법이 있을 것이다. 예를 들어 우리가 계산하고 있는 "답"이 a라는 수열이고, a를 구하기 위해 곱해야 하는 것이 a*b라고 할 수 있다. 임의의 수열 x에 대해, x:L을 L까지의 원소들을 딴 다항식이라고 하자. 그렇다면 현재 알고 있는 것은 (a:L-1)*(b:L-1)이며, 구해야 하는답은 (a:M-1)*(B:M-1)이다. 이를 계산하기 위해서는 ((a:M-1)-a:(L-1))*((b:M-1)-(b:L-1))+(a:(L-1))*((b:M-1)-(b:L-1))+(a:L-1)*(b:L-1)+(b:(L-1))*((a:M-1)-(a:L-1))+(b:L-1)*(a:L-1)일...것이다(식정리가 틀렸을수도 있는데 쨌든 핵심을 알려주면 당신도 계산할 수 있을 것이다.). 핵심은, L-1들끼리의 곱은 이미 우리가 답을 알고 있다는 점, 그리고 x:M-1과 x:L-1의 차이는 M-L이며, 그 전체 다항식의 최저차항이 L정도 된다는 것이다. R차 이상은 현재 다루는 구간에 대해서는 무의미하기 때문에, 그 옆에 곱해지는 다항식에 대해서는 단지 R-L차항까지만 보면 되며, 따라서 이들을 곱하는데에 걸리는 시간복잡도는 (R-L)log(R-L)에 지배당한다. 따라서 분할정복적 구조가 성립할 여지가 존재하게 되며, 전체 시간복잡도는 대략적으로 O(nlog^2n)이 된다. 음 뭔가 나도 잘 이해한건진 모르겠고 한번도 짜본적도 없어서 나중에 짜봐야 알 것 같다. rkm님 블로그에 좋은 설명이 있는 것 같다. 4월의 PS 일지 - Part 4 :: rkm0959 (tistory.com)
4. FWHT
일단 다음 포스트를 보고 글을 작성했고, 많은 부분을 참고했기 때문에 읽어보면 좋을 것이다: FWHT (Fast Walsh-Hadamard Transform) (tistory.com) 개인적으로 다른 변환들에 비해 가장 "이질적인" 아이디어다. 우선, 이 문제에 대한 기본적인 intuition은 다음이다: 25563번: AND, OR, XOR (acmicpc.net)
이 문제에 대해 간략히 요약하자면 결국 어떤 수열에 대한 convolution을 계산하는데, 어떤 값이 더해지는 인덱스가 어떻게 계산되는지에 따라 convolution의 종류가 달라진다고 볼 수 있다.
본 포스팅에서 다루는 convolution은 다음과 같은 꼴이다: $$(a*b)_k=\sum_{i*j=k}^{}a_ib_j$$
즉, 어떤 수열 a와 b에 대해서 *라는 연산을 통해 계산되는 인덱스에 두 원소의 곱을 더해주는 것이라 할 수 있다. 이때, 이 연산이 덧셈이라면 일반적인 다항식의 곱셈이 된다. FWHT는 이러한 원소들이 and,or,xor과 같은 비트연산일때, 그 convolution을 빠르게 구해주고, 그 역연산 또한 구해주도록 하는 변환이다.
사실 FWHT는 각 비트연산의 종류에 따라 다르게 정의된다. 일단 가장 간단한 or에 대해 먼저 분석해보자. or 위에서의 FWHT의 k번째 원소는 해당 수열에서 k랑 or한 값이 k가 되는 인덱스의 원소들을 전부 더한 값으로 정의된다. 만약 sos_dp에 대해서 알고 있다면, 그것과 정확히 동일한 개념이라는 사실을 눈치챌 수 있을 것이다. 또한, sos dp가 계산되는 방식은 그 역연산을 정확하게 따라할 수 있기 때문에, 역연산 또한 존재한다는 사실을 알 수 있다.
무엇보다 가장 놀라운 것은, FWHT는 or-convolution의 결괏값들을 단순한 곱셈만으로 처리할 수 있도록 해준다는 것이다.
즉, 다음과 같은 성질이 성립한다:
어떤 두 수열 a와 b에 대해,(FWHT(a))_k(FWHT(b))_k=(FWHT(a*b))_k
이에 대한 증명은 다음과 같은 수식을 본다면 이해할 수 있을 것이다.
$$FWHT(a)_kFWHT(b)_k=\sum_{i\vee k=k}^{}\sum_{j\vee k=k}^{}a_ib_j=\sum_{j\vee i\vee k=i}^{}a_ib_j=\sum_{l\vee k=k}^{}\sum_{i\vee j=l}^{}a_ib_j=FWHT(a*b)_k$$
이 결과가 의미하는 것은, sos dp가 단순하게 부분집합의 합을 구하는데에서 그치지 않고, 일종의 변환으로 작용하여 or 사이의 convolution을 오직 결과를 곱하는것만으로 빠르게 계산하고 역연산도 지원해주는 아이디어로서 작용한다는 것이다.
아무튼, 이제 and convolution에 대해 다룰 것인데, 사실 and가 가지고 있는 구조는 or의 구조와 매우 유사하기 때문에 and_FWHT가 j^i=i를 만족하는 모든 j에 대한 합이라는 점과, 그것이 or때와 거의 대부분의 성질을 공유한다는 점만 짚고 넘어가겠다.
이제 xor인데, 이 경우에는 FWHT의 원소간 곱셈이 convolution의 결괏값과 일치함을 증명하기가 쉽지 않다. 단지, 다음과 같은 FWHT가 그 역할을 수행할 수 있음만을 언급하겠다.
FWHT: i번 원소를 계산하고 싶다고 하자. 그렇다면, i와 or한 비트의 개수가 짝수인 인덱스의 값들을 전부 더하고, 그렇지 않은 것들은 전부 뺀다.
아무튼 위의 문제를 어떻게 풀어야 하는지에 대해 생각해보자. 잘 생각해보면, 배열에 각각 and,or,xor convolution을 스스로에게 취하고 나서 i,j가 같은 등의 예외처리만 조금 해주면 답이 나온다는 사실을 알 수 있다. 따라서 단순히 자신의 FWHT를 계산하고, 각 항들을 제곱해준 뒤 역 FWHT를 취해주면 답을 얻을 수 있다.
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
void fwht_or(vector<ll> &a,bool invert)
{
ll dir=(invert?(-1):(1));
ll n=a.size();
for(ll s=2,h=1;s<=n;s<<=1,h<<=1)
{
for(ll l=0;l<n;l+=s)
{
for(ll i=0;i<h;i++)
{
a[l+i+h]+=dir*a[l+i];
}
}
}
}
void fwht_and(vector<ll> &a,bool invert)
{
ll dir=(invert?(-1):1);
ll n=a.size();
for(ll s=2,h=1;s<=n;s<<=1,h<<=1)
{
for(ll l=0;l<n;l+=s)
{
for(ll i=0;i<h;i++)
{
a[l+i]+=dir*a[l+i+h];
}
}
}
}
void fwht_xor(vector<ll> &a,bool invert)
{
ll n=a.size();
for(ll s=2,h=1;s<=n;s<<=1,h<<=1)
{
for(ll l=0;l<n;l+=s)
{
for(ll i=0;i<h;i++)
{
ll t=a[i+l+h];
a[i+l+h]=a[i+l]-t;
a[i+l]+=t;
if(invert)
{
a[i+l]/=2;
a[i+l+h]/=2;
}
}
}
}
}
vector<ll> a,b;
ll n,i,k,s;
int main()
{
a.resize(1048576);
scanf("%lld %lld",&n,&k);
for(i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
}
b.clear();
//printf("?");
b.resize(1048576);
for(i=1;i<=n;i++)
{
b[a[i]]++;
}
//printf("!");
fwht_and(b,false);
//printf("?");
for(i=0;i<b.size();i++)
{
b[i]*=b[i];
}
fwht_and(b,true);
for(i=1;i<=n;i++)
{
b[a[i]]--;
}
for(i=0;i<b.size();i++)
{
b[i]/=2;
}
printf("%lld ",b[k]);
b.clear();
b.resize(1048576);
for(i=1;i<=n;i++)
{
b[a[i]]++;
}
fwht_or(b,false);
for(i=0;i<b.size();i++)
{
b[i]*=b[i];
}
fwht_or(b,true);
for(i=1;i<=n;i++)
{
b[a[i]]--;
}
for(i=0;i<b.size();i++)
{
b[i]/=2;
}
printf("%lld ",b[k]);
b.clear();
b.resize(1048576);
for(i=1;i<=n;i++)
{
b[a[i]]++;
}
fwht_xor(b,false);
for(i=0;i<b.size();i++)
{
b[i]*=b[i];
}
fwht_xor(b,true);
b[0]-=n;
for(i=0;i<b.size();i++)
{
b[i]/=2;
}
printf("%lld",b[k]);
}
다음은 소스코드이다. 그냥 각 연산들에 대한 fwht가 성립하는 이유를 완벽히 이해하지 않고도, 위의 형식대로 fwht를 계산하면서 쓰기만 할 수도 있기 때문에 외워두면 쓸만할 것이다.
5. 느낀점
"변환"이라는 것이 얼마나 유용한 구조인지를 다시 한번 느꼈다. 일반적으로 수학에서도 푸리에 변환에서부터, z-변환, 라플라스 변환, 뫼비우스 변환등의 것들을 배웠었지만, 이렇게 비트 연산과도 같은, 어찌 보면 수학적으로 아주 근본적이진 않은 것 같은 체계 위에서도 변환을 통해 계산의 난해함을 해결하는 것이 인상적이라고 생각한다. 또한, sos dp의 특별한 성질에 대해서도 알 수 있었다.
아마 다음 주제는 대충 다항식가지고 장난질 치는 주제가 될 것 같다.