川大2021校赛K题题解
作者:互联网
[没传送门]
简化题面:有一个长度为\(n\)的串\(S\),\(m\)个询问,每次询问\(k\)个以\(p_i\)开始的后缀\(suf_i \ (suf_i=S[p_i,n],i \in [1,k])\),求这些后缀两两之间的最长公共前缀的长度的和。(\(1\leqslant n \leqslant 5 \times 10^5,1 \leqslant m \leqslant 2000, \sum k \leqslant 5 \times 10^5\))
这道题算是一道比较板儿的科技题,会相应的算法,这题就能做出来了。
首先看这种前后缀的公共部分,基本能想到SAM,我们只要把串线翻转一下,就变成SAM能做的不同前缀的公共后缀了。
而不同前缀的公共后缀就对应着后缀链接树上这两个节点的lca。所以问题演化为:求\(k\)个点两两之间的lca的权值和。
这个用树形dp可以在单次\(O(n)\)的时间复杂度下解决。
具体来说,就是动态的维护每一个节点子树中被标记的节点的个数。考虑节点\(u\)能成为多少个节点的lca:对于\(u\)的新的一棵子树\(v\),他对答案的贡献就是\(v\)之前\(u\)的子树和\(siz[u]\)乘以\(siz[v]\)(注意,这个\(siz\)并不是子树大小,而是子树中被标记的点的个数)。
因为lca为\(u\)的子树中的点会在子树中被统计,lca为\(u\)的祖先的点会在\(u\)的祖先中统计,所以就能刚好不重不漏的统计出每个点成为lca的情况数了。
上述算法的单次查询复杂度为\(O(n)\),而总复杂度\(O(nm)\)显然是不能接受的,观察到\(\sum k \leqslant 5 \times 10^5\),就启发我们用虚树来优化上述的dp。
虚树能做到在多次询问的情况下,复杂度和询问的点数之和有关,所以复杂度降低成\(O(k\log n)\)(求lca用\(O(\log n)\)实现的)
比赛的时候卡在虚树的前一步,因为这东西早忘光了。
#include#include#include#include#include#includeusing namespace std; #define enter puts("") #define space putchar(' ') #define Mem(a, x) memset(a, x, sizeof(a)) #define In inline #define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt) typedef long long ll; typedef double db; const int INF = 0x3f3f3f3f; const db eps = 1e-8; const int maxn = 5e5 + 5; const int max2 = 1e6 + 5; const int maxs = 27; const int N = 20; In ll read() { ll ans = 0; char ch = getchar(), las = ' '; while(!isdigit(ch)) las = ch, ch = getchar(); while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar(); if(las == '-') ans = -ans; return ans; } In void write(ll x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar(x % 10 + '0'); } In void MYFILE() { #ifndef mrclr freopen("ha.in", "r", stdin); freopen("ha.out", "w", stdout); #endif } char s[maxn]; int n, m; int a[max2], id[maxn]; struct Sam { int tra[max2][maxs], link[max2], len[max2], cnt, las; In void init() {link[cnt = las = 0] = -1; Mem(tra[0], 0);} In void insert(int c, int x) { int now = ++cnt, p = las; Mem(tra[now], 0); len[now] = len[p] + 1, id[x] = now; while(~p && !tra[p][c]) tra[p][c] = now, p = link[p]; if(p == -1) link[now] = 0; else { int q = tra[p][c]; if(len[q] == len[p] + 1) link[now] = q; else { int clo = ++cnt; memcpy(tra[clo], tra[q], sizeof(tra[q])); len[clo] = len[p] + 1; link[clo] = link[q]; link[q] = link[now] = clo; while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p]; } } las = now; } }S; struct Edge { int nxt, to; }e[max2]; int head[max2], ecnt = -1; In void addEdge(int x, int y) { e[++ecnt] = (Edge){head[x], y}; head[x] = ecnt; } int dep[max2], fa[N + 2][max2], dfsx[max2], cnt = 0;; In void dfs(int now, int _f) { dfsx[now] = ++cnt; for(int i = 1; (1 << i) <= dep[now]; ++i) fa[i][now] = fa[i - 1][fa[i - 1][now]]; forE(i, now, v) { if(v == _f) continue; dep[v] = dep[now] + 1; fa[0][v] = now; dfs(v, now); } } In int lca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); for(int i = N; i >= 0; --i) if(dep[x] - (1 << i) >= dep[y]) x = fa[i][x]; if(x == y) return x; for(int i = N; i >= 0; --i) if(fa[i][x] ^ fa[i][y]) x = fa[i][x], y = fa[i][y]; return fa[0][x]; } int vir[max2]; int st[max2], top = 0; int pa[max2]; bool cmp(const int& x, const int& y) {return dfsx[x] < dfsx[y];} In void build(int& m) { vir[++m] = 1; top = 0; sort(vir + 1, vir + m + 1, cmp); int tp = m; for(int i = 1; i dep[Lca]) { if(dep[st[top - 1]] < dep[Lca]) pa[st[top]] = Lca; --top; } if(Lca != st[top]) { vir[++m] = Lca; pa[Lca] = st[top]; st[++top] = Lca; } pa[st[++top] = now] = Lca; } sort(vir + 2, vir + m + 1, cmp); } bool vis[max2]; int siz[max2]; In ll DP(int m) { ll ret = 0; for(int i = m; i > 1; --i) { int now = vir[i], fa = pa[now]; ret += 1LL * siz[fa] * siz[now] * a[fa]; siz[fa] += siz[now]; } return ret; } In int solve_init() { int K = read(), ret = 0; for(int i = 1; i <= K; ++i) { int x = id[n - read()] + 1; if(!vis[x]) vir[++ret] = x, siz[x] = vis[x] = 1; } return ret; } int main() { // MYFILE(); Mem(head, -1), ecnt = -1; n = read(), m = read(); scanf("%s", s); reverse(s, s + n); S.init(); for(int i = 0; i < n; ++i) S.insert(s[i] - 'a', i); for(int i = 0; i <= S.cnt; ++i) addEdge(S.link[i] + 1, i + 1), a[i + 1] = S.len[i]; dep[1] = 1, dfs(1, 0); for(int i = 1; i <= m; ++i) { int K = solve_init(); build(K); write(DP(K)), enter; for(int j = 1; j <= K; ++j) siz[vir[j]] = vis[vir[j]] = 0; } return 0; }
标签:fa,int,题解,max2,川大,link,tra,校赛,now 来源: https://blog.51cto.com/u_15234622/2831635