LOADING

加载过慢请开启缓存 浏览器默认开启

virtual-tree

2025/10/18 virtual tree

好久没更学习笔记了。

算法简介

通常情况下,如果是对树上单点的询问,我们是用不到树上的所有点的。

换句话说,如果我们树上能用到的节点非常少,那我们就可以避免遍历整棵树,从而降低复杂度。

下面用一个题说。

消耗战/【模板】虚树

先想朴素的 dp 做法:设 $f_i$ 表示 $i$ 不与子树中任意一个关键点连接的最小代价。

枚举 $i$ 的子树 $v$,则有转移方程:

  • 若 $v$ 不是关键点 $f_i=f_i+\min(f_v,w(i,v))$

  • 若 $v$ 是关键点 $f_i=f_i+w(i,v)$

非常简单对吧。

但是这个玩意是 $O(nq)$ 的,根本跑不动。

但是 $\sum k$ 比较小,所以我们考虑如果每次只遍历 $k$ 个点,复杂度就可以控制住了。

然后再观察转移式子,发现 $min(i,v)$ 这一项实际上是链上最小值,这个信息是轻松维护的,我们考虑建出虚树,对虚树 dp。

然后我们就可以考虑如何建虚树了。

一些定义

虚树上的节点我们称为关键点,关键点包含询问的节点和两两之间的 LCA。

方法一:两遍排序 LCA

先将关键点按 DFS 序排序后,两两求 LCA。

对选出来的点集去重后按 DFS 序排序,再两两求 LCA,并连边。

方法二:单调栈构建虚树

我们用单调栈维护虚树上的一条链,栈中的点在虚树上是相邻的,而且栈中自底向上 DFS 序递增。

然后我们就对不同的情况分讨即可。贴一个代码。

sort(h+1,h+1+k,cmp);
t=0;sta[++t]=h[1];
for(int i=2;i<=k;i++){
    int now=h[i],lca=LCA(now,sta[t]);
    while(1){
        if(dep[lca]>=dep[sta[t-1]]){
            if(lca!=sta[t]){
                M[lca].push_back(sta[t]);
                if(lca!=sta[t-1]) sta[t]=lca;
                else t--;
            }
            break;
        }else{
            M[sta[t-1]].push_back(sta[t]);
            t--;
        }
    }
    sta[++t]=now;
}
while(--t) M[sta[t]].push_back(sta[t+1]);

然后对着新生成的树正常跑 dp 即可。

撤销的时候按 DFS 序撤销,不要 memest 式清空,不然复杂度就退化了。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
constexpr int N = 2.5e5+10;
int dfn[N],dfx,siz[N],top[N],hson[N],fa[N],dep[N];
struct Edge{
    int v;
    ll w;
};
ll minv[N];
int n,m,k,h[N];
bool q[N];
vector<Edge> G[N];
vector<int> M[N];
void dfs1(int u,int f){
    fa[u]=f;siz[u]=1;dep[u]=dep[f]+1;
    for(auto to:G[u]){
        int v=to.v;
        if(v==f) continue;
        minv[v]=min(minv[u],to.w);
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[hson[u]]) hson[u]=v;
    }
}
void dfs2(int u,int tp){
    top[u]=tp;dfn[u]=++dfx;
    if(!hson[u]) return ;
    dfs2(hson[u],tp);
    for(auto to:G[u]){
        int v=to.v;
        if(v==fa[u] || v==hson[u]) continue;
        dfs2(v,v);
    }
}
int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]<dep[y] ? x : y;
}
bool cmp(int x,int y){
    return dfn[x]<dfn[y];
}
int sta[N],t;
ll dfs3(int u){
    ll sum=0,tmp;
    for(int v:M[u]){
        sum+=dfs3(v);
    }
    if(q[u])tmp=minv[u];
    else tmp=min(sum,minv[u]);
    q[u]=false;
    M[u].clear();
    return tmp;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    minv[1]=1e18;
    cin>>n;
    for(int i=1,u,v,w;i<n;i++){
        cin>>u>>v>>w;
        G[u].push_back({v,w});G[v].push_back({u,w});
    }
    dfs1(1,0);dfs2(1,1);
    cin>>m;
    while(m--){
        cin>>k;
        for(int i=1;i<=k;i++){
            cin>>h[i];q[h[i]]=1;
        }
        sort(h+1,h+1+k,cmp);
        t=0;sta[++t]=h[1];
        for(int i=2;i<=k;i++){
            int now=h[i],lca=LCA(now,sta[t]);
            while(1){
                if(dep[lca]>=dep[sta[t-1]]){
                    if(lca!=sta[t]){
                        M[lca].push_back(sta[t]);
                        if(lca!=sta[t-1]) sta[t]=lca;
                        else t--;
                    }
                    break;
                }else{
                    M[sta[t-1]].push_back(sta[t]);
                    t--;
                }
            }
            sta[++t]=now;
        }
        while(--t) M[sta[t]].push_back(sta[t+1]);
        cout << dfs3(sta[1]) << '\n';
    }
    return 0;
}