题目

分类讨论下换根就完事了

用线段树维护dfs序

也可以用树链剖分

#include <bits/stdc++.h>
#define il inline
#define Max 5000005
#define ls(x) x<<1
#define rs(x) x<<1|1
#define ll long long
#define getchar() (tt==ss&&(tt=(ss=In)+fread(In,1,Max,stdin),ss==tt)?EOF:*ss++)
using namespace std;
char In[Max],*ss=In,*tt=In;
il int read()
{
    char c=getchar();
    int x=0,f=1;
    while(c>'9'||c<'0')
    {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9')
    {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
struct node
{
    int t,nt,w;
}e[Max<<1];
int head[Max],w[Max],tot,n,m,cnt,rt,d[Max],dfn[Max],f[Max][30],sz[Max],rk[Max];
ll t[Max<<2],tg[Max<<2];
il void add(int u,int v)
{
    e[++tot].t=v;
    e[tot].nt=head[u];
    head[u]=tot;
}
il void dfs(int u,int fa)
{
    d[u]=d[fa]+1;
    sz[u]=1;
    dfn[u]=++cnt;
    rk[cnt]=u;
    f[u][0]=fa;
    for(int i=1;i<=25;i++)
        f[u][i]=f[f[u][i-1]][i-1];
    for(int i=head[u];i;i=e[i].nt)
    {
        int v=e[i].t;
        if(v==fa) continue;
        dfs(v,u);
        sz[u]+=sz[v];
    }
}
il int lca(int u,int v)
{
    if(d[u]<d[v]) swap(u,v);
    for(int i=25;i>=0;i--)
        if(d[f[u][i]]>=d[v]) u=f[u][i];
    if(u==v) return u;
    for(int i=25;i>=0;i--)
    {
        if(f[u][i]!=f[v][i])
            u=f[u][i],v=f[v][i];
    }
    return f[u][0];
}
il void pushup(int x)
{
    t[x]=t[ls(x)]+t[rs(x)];
}
il void wk(int x,int l,int r,ll k)
{
    t[x]+=1ll*(r-l+1)*k;
    tg[x]+=k;
}
il void pushdown(int x,int l,int r)
{
    if(!tg[x]) return;
    int mid=(l+r)>>1;
    wk(ls(x),l,mid,tg[x]);
    wk(rs(x),mid+1,r,tg[x]);
    tg[x]=0;
}
il void build(int x,int l,int r)
{
    if(l==r)
    {
        t[x]=w[rk[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(ls(x),l,mid);
    build(rs(x),mid+1,r);
    pushup(x);
}
il void mdf(int x,int l,int r,int ql,int qr,ll k)
{
    if(ql<=l&&r<=qr)
    {
        t[x]+=1ll*(r-l+1)*k;
        tg[x]+=k;
        return;
    }
    int mid=(l+r)>>1;
    pushdown(x,l,r);
    if(ql<=mid) mdf(ls(x),l,mid,ql,qr,k);
    if(qr>mid) mdf(rs(x),mid+1,r,ql,qr,k);
    pushup(x);
}
il ll qry(int x,int l,int r,int ql,int qr)
{
    if(ql<=l&&r<=qr)
    {
        return t[x];
    }
    ll mid=(l+r)>>1,res=0;
    pushdown(x,l,r);
    if(ql<=mid) res+=qry(ls(x),l,mid,ql,qr);
    if(qr>mid) res+=qry(rs(x),mid+1,r,ql,qr);
    pushup(x);
    return res;
}
il int find(int u,int d)
{
    for(int i=25;i>=0;i--) if(d&(1<<i)) u=f[u][i];
    return u;
}
il int get(int u,int f)
{
    return find(u,d[u]-d[f]-1);
}
int q1[Max],q2[Max];
signed main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++) w[i]=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    rt=1;
    dfs(1,0);
    build(1,1,n);
    while(m--)
    {
        int opt=read();
        if(opt==1)
        {
            rt=read();
        }
        else if(opt==2)
        {
            int u=read(),v=read(),w=read();
            int x=lca(u,v);
            if(dfn[x]<=dfn[rt]&&dfn[rt]<=dfn[x]+sz[x]-1)
            {
                int r1=lca(u,rt),r2=lca(v,rt);
                mdf(1,1,n,1,n,w);
                if(r1!=rt&&r2!=rt)
                {
                    if(d[r1]<d[r2]) swap(r1,r2);
                    x=get(rt,r1);
                    mdf(1,1,n,dfn[x],dfn[x]+sz[x]-1,-w);
                }
            }
            else mdf(1,1,n,dfn[x],dfn[x]+sz[x]-1,w);
        }
        else
        {
            int u=read();
            ll ans=0;
            if(dfn[u]<=dfn[rt]&&dfn[rt]<=dfn[u]+sz[u]-1)
            {
                ans+=qry(1,1,n,1,n);
                if(u!=rt)
                {
                    int x=get(rt,u);
                    ans-=qry(1,1,n,dfn[x],dfn[x]+sz[x]-1);
                }
            }
            else ans=qry(1,1,n,dfn[u],dfn[u]+sz[u]-1);
            printf("%lld\n",ans);
        }
    }
}