题目
分类讨论下换根就完事了
用线段树维护dfs序
也可以用树链剖分
#include <bits/stdc++.h>
#define il inline
#define Max 5000005
#define ls(x) x<<1
#define rs(x) x<<1|1
#define ll long long
#define getchar() (tt==ss&&(tt=(ss=In)+fread(In,1,Max,stdin),ss==tt)?EOF:*ss++)
using namespace std;
char In[Max],*ss=In,*tt=In;
il int read()
{
char c=getchar();
int x=0,f=1;
while(c>'9'||c<'0')
{
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9')
{
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
struct node
{
int t,nt,w;
}e[Max<<1];
int head[Max],w[Max],tot,n,m,cnt,rt,d[Max],dfn[Max],f[Max][30],sz[Max],rk[Max];
ll t[Max<<2],tg[Max<<2];
il void add(int u,int v)
{
e[++tot].t=v;
e[tot].nt=head[u];
head[u]=tot;
}
il void dfs(int u,int fa)
{
d[u]=d[fa]+1;
sz[u]=1;
dfn[u]=++cnt;
rk[cnt]=u;
f[u][0]=fa;
for(int i=1;i<=25;i++)
f[u][i]=f[f[u][i-1]][i-1];
for(int i=head[u];i;i=e[i].nt)
{
int v=e[i].t;
if(v==fa) continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
il int lca(int u,int v)
{
if(d[u]<d[v]) swap(u,v);
for(int i=25;i>=0;i--)
if(d[f[u][i]]>=d[v]) u=f[u][i];
if(u==v) return u;
for(int i=25;i>=0;i--)
{
if(f[u][i]!=f[v][i])
u=f[u][i],v=f[v][i];
}
return f[u][0];
}
il void pushup(int x)
{
t[x]=t[ls(x)]+t[rs(x)];
}
il void wk(int x,int l,int r,ll k)
{
t[x]+=1ll*(r-l+1)*k;
tg[x]+=k;
}
il void pushdown(int x,int l,int r)
{
if(!tg[x]) return;
int mid=(l+r)>>1;
wk(ls(x),l,mid,tg[x]);
wk(rs(x),mid+1,r,tg[x]);
tg[x]=0;
}
il void build(int x,int l,int r)
{
if(l==r)
{
t[x]=w[rk[l]];
return;
}
int mid=(l+r)>>1;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
pushup(x);
}
il void mdf(int x,int l,int r,int ql,int qr,ll k)
{
if(ql<=l&&r<=qr)
{
t[x]+=1ll*(r-l+1)*k;
tg[x]+=k;
return;
}
int mid=(l+r)>>1;
pushdown(x,l,r);
if(ql<=mid) mdf(ls(x),l,mid,ql,qr,k);
if(qr>mid) mdf(rs(x),mid+1,r,ql,qr,k);
pushup(x);
}
il ll qry(int x,int l,int r,int ql,int qr)
{
if(ql<=l&&r<=qr)
{
return t[x];
}
ll mid=(l+r)>>1,res=0;
pushdown(x,l,r);
if(ql<=mid) res+=qry(ls(x),l,mid,ql,qr);
if(qr>mid) res+=qry(rs(x),mid+1,r,ql,qr);
pushup(x);
return res;
}
il int find(int u,int d)
{
for(int i=25;i>=0;i--) if(d&(1<<i)) u=f[u][i];
return u;
}
il int get(int u,int f)
{
return find(u,d[u]-d[f]-1);
}
int q1[Max],q2[Max];
signed main()
{
n=read(),m=read();
for(int i=1;i<=n;i++) w[i]=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
rt=1;
dfs(1,0);
build(1,1,n);
while(m--)
{
int opt=read();
if(opt==1)
{
rt=read();
}
else if(opt==2)
{
int u=read(),v=read(),w=read();
int x=lca(u,v);
if(dfn[x]<=dfn[rt]&&dfn[rt]<=dfn[x]+sz[x]-1)
{
int r1=lca(u,rt),r2=lca(v,rt);
mdf(1,1,n,1,n,w);
if(r1!=rt&&r2!=rt)
{
if(d[r1]<d[r2]) swap(r1,r2);
x=get(rt,r1);
mdf(1,1,n,dfn[x],dfn[x]+sz[x]-1,-w);
}
}
else mdf(1,1,n,dfn[x],dfn[x]+sz[x]-1,w);
}
else
{
int u=read();
ll ans=0;
if(dfn[u]<=dfn[rt]&&dfn[rt]<=dfn[u]+sz[u]-1)
{
ans+=qry(1,1,n,1,n);
if(u!=rt)
{
int x=get(rt,u);
ans-=qry(1,1,n,dfn[x],dfn[x]+sz[x]-1);
}
}
else ans=qry(1,1,n,dfn[u],dfn[u]+sz[u]-1);
printf("%lld\n",ans);
}
}
}
最后一次更新于2021-09-28 02:20:10
0 条评论