Logo UKBwyx的博客

博客

个人多项式 NTT 模版

2026-01-02 12:31:02 By UKBwyx

多项式科技(持续更新)

namespace poly {
    const int mod=998244353;
    long long w[1<<21];
    int rv[1<<21];
    pii exgcd(int x,int y) {
        if(y==0)return mp(1,0);
        pii t=exgcd(y,x%y);
        return mp(t.second,t.first-t.second*(x/y));
    }
    inline long long ksm(long long x,long long y) {
        long long res=1;
        while(y) {
            if(y&1)res=res*x%mod;
            x=x*x%mod,y>>=1;
        }
        return res;
    }
    inline void trans(vector<int>&a,int len,bool intt) {
        int n1=1<<len;
        a.resize(n1);
        for(int i=1; i<n1; i++) {
            rv[i]=rv[i>>1]>>1;
            if(i&1)rv[i]|=1<<len-1;
        }
        w[0]=1;
        for(int i=0; i<n1; i++)if(rv[i]<i)swap(a[i],a[rv[i]]);
        for(int i=0; i<n1; i++) {
            a[i]%=mod;
            if(a[i]<0)a[i]+=mod;
        }
        long long w1=ksm(3,mod-1>>len);
        if(intt)w1=(exgcd(mod,w1).second+mod)%mod;
        for(int t=1; t<=len; t++) {
            int t1=1<<t,t2=1<<t-1,s=n1>>t;
            long long w2=ksm(w1,s);
            for(int i=1; i<t2; i++)w[i]=w[i-1]*w2%mod;
            for(int i=0; i<n1; i+=t1) {
                for(int j=0; j<t2; j++) {
                    int u=i+j,v=i+j+t2;
                    int tt=a[u];
                    long long t=w[j]*a[v]%mod;
                    a[v]=tt-t;
                    if(a[v]<0)a[v]+=mod;
                    a[u]=tt+t;
                    if(a[u]>=mod)a[u]-=mod;
                }
            }
        }
        long long t=exgcd(mod,n1).second;
        if(intt)for(int i=0; i<n1; i++)a[i]=a[i]*t%mod;
        for(int i=0; i<n1; i++)if(a[i]<0)a[i]+=mod;
    }
    inline void tz(vector<int>&a) {
        int la=a.size();
        for(int i=a.size()-1; i>=1; i--) {
            if(a[i])break;
            la=i;
        }
        a.resize(la);
    }
    inline void qf(vector<int>&a) {
        for(int i=0; i<a.size(); i++)if(a[i])a[i]=mod-a[i];
    }
    vector<int>add(vector<int>&a,vector<int>&b) {
        int n=max(a.size(),b.size());
        vector<int>res;
        a.resize(n);
        b.resize(n);
        res.resize(n);
        for(int i=0; i<n; i++) {
            res[i]=a[i]+b[i];
            if(res[i]>=mod)res[i]-=mod;
        }
        return res;
    }
    vector<int>minus(vector<int>&a,vector<int>&b) {
        int n=max(a.size(),b.size());
        vector<int>res;
        a.resize(n);
        b.resize(n);
        res.resize(n);
        for(int i=0; i<n; i++) {
            res[i]=a[i]-b[i];
            if(res[i]<0)res[i]+=mod;
        }
        return res;
    }
    inline void mul(vector<int>&a,vector<int>&b) {
        for(int i=0; i<a.size(); i++)a[i]=(long long)a[i]*b[i]%mod;
    }
    inline int get_len(int x) {
        return log2(x)+1;
    }
    inline void prt(vector<int>&a) {
        for(int i=0; i<a.size(); i++) {
            cout<<a[i]<<" ";
        }
        cout<<"\n";
    }
    inline vector<int> conv(vector<int>a,vector<int>b) {
        int l=get_len(a.size()+b.size());
        trans(a,l,0);
        trans(b,l,0);
        mul(a,b);
        trans(a,l,1);
        return a;
    }
    void inv(vector<int>&a,int n) {
        vector<int>b;
        a.resize(n);
        for(int i=0; i<n; i++)if(a[i]<0)a[i]+=mod;
        b.pb(exgcd(mod,a[0]).second);
        for(int i=2,l=1; i<n<<1; i<<=1,l++) {
            vector<int>t,c;
            c=b;
            for(int j=0; j<c.size(); j++)c[j]=(c[j]<<1)%mod;
            t.resize(i);
            c.resize(i);
            for(int j=0; j<min(i,(int)a.size()); j++)t[j]=a[j];
            trans(t,l+1,0);
            trans(b,l+1,0);
            mul(b,b);
            mul(t,b);
            trans(t,l+1,1);
            for(int j=0; j<i; j++) {
                c[j]=(c[j]-t[j])%mod;
            }
            swap(b,c);
        }
        a=b;
    }
    vector<int> div(vector<int>a,vector<int>b) {
        tz(b);
        int n=a.size(),m=b.size();
        if(n<m)return a;
        reverse(a.begin(),a.end());
        reverse(b.begin(),b.end());
        inv(b,n-m+1);
        a.resize(n-m+1);
        a=conv(a,b);
        a.resize(n-m+1);
        reverse(a.begin(),a.end());
        return a;
    }
    vector<int> rem(vector<int>&a,vector<int>&b,vector<int>&d) {
        vector<int>t=conv(b,d);
        vector<int>res=minus(a,t);
        tz(res);
        return res;
    }
    vector<int> Mod(vector<int>a,vector<int>b) {
        vector<int>t=div(a,b);
        return rem(a,b,t);
    }
    vector<int>ny;
    inline void csh(int n) {
        ny.resize(n+1);
        ny[1]=1;
        for(int i=2; i<=n; i++) {
            ny[i]=mod-(mod/i)*(long long)ny[mod%i]%mod;
        }
    }
    vector<int> ln(vector<int>a) {
        int n=a.size();
        vector<int>b;
        b.resize(n-1);
        for(int i=0; i<n-1; i++) {
            b[i]=a[i+1]*(i+1ll)%mod;
        }
        inv(a,n-1);
        a=conv(a,b);
        a.resize(n);
        csh(n);
        for(int i=n-1; i>=1; i--) {
            a[i]=a[i-1]*(long long)ny[i]%mod;
        }
        a[0]=0;
        return a;
    }
    vector<vector<int>>vl;
    int val_one(int x,vector<int>&a) {
        int res=0,cur=1;
        for(int i=0; i<a.size(); i++) {
            res=(res+a[i]*(long long)cur)%mod;
            cur=cur*(long long)x%mod;
        }
        return res;
    }
    void val1(vector<int>&x,int wz) {
        if(x.size()==1) {
            vl[wz].resize(2);
            vl[wz][1]=1;
            vl[wz][0]=-x[0];
            return;
        }
        int mid=x.size()>>1,lt=wz<<1,rt=wz<<1|1;
        vector<int>l,r;
        l.resize(mid);
        r.resize(x.size()-mid);
        for(int i=0; i<mid; i++)l[i]=x[i];
        for(int i=mid; i<x.size(); i++)r[i-mid]=x[i];
        val1(l,lt);
        val1(r,rt);
        vl[wz]=conv(vl[lt],vl[rt]);
        tz(vl[wz]);
    }
    vector<int> val2(vector<int>&x,vector<int>&a,int wz) {
        if(x.size()==1) {
            return {val_one(x[0],a)};
        }
        int mid=x.size()>>1,lt=wz<<1,rt=wz<<1|1;
        vector<int>l,r,t=Mod(a,vl[wz]);
        l.resize(mid);
        r.resize(x.size()-mid);
        for(int i=0; i<mid; i++)l[i]=x[i];
        for(int i=mid; i<x.size(); i++)r[i-mid]=x[i];
        vector<int>res=val2(l,t,lt);
        vector<int>res1=val2(r,t,rt);
        int t1=res.size();
        res.resize(res.size()+res1.size());
        for(int i=0; i<res1.size(); i++) {
            res[t1+i]=res1[i];
        }
        return res;
    }
    vector<int>val(vector<int>a,vector<int>x) {
        vl.resize(x.size()<<2+1);
        val1(x,1);
        return val2(x,a,1);
    }
};

功能可以自己琢磨。

评论

xuanxuan0604
好评,下次出个FFT

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。