题目大意
一个长度为$n$的字符串$S$,令$T_i$表示它从第$i$个字符开始的后缀。
求:
其中,$len(a)$表示字符串$a$的长度,$lcp(a,b)$表示字符串$a$和字符串$b$的最长公共前缀。
题目分析
$\sum_{1\le i\lt j\le n}len(T_i)+len(T_j)=(n-1)\sum_{i=1}^nlen(T_i)=\frac{(n-1)n(n+1)}2$
考虑求出后半部分:
即所有后缀两两的最长公共前缀的长度之和。
将原串翻转,那么我们要求的是:所有前缀的两两最长公共后缀之和。
因为后缀自动机有性质:两个串的最长公共后缀,维护对应的状态在前缀树上的LCA状态
将所有前缀对应结点染色,因此我们要求对于每个结点是多少黑色结点的LCA。
这个问题可以按逆拓扑序从下至上做一次递推的统计解决。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
| #include<algorithm> #include<iostream> #include<iomanip> #include<cstring> #include<cstdlib> #include<climits> #include<vector> #include<cstdio> #include<cmath> #include<queue> using namespace std;
typedef long long LL;
inline const int Get_Int() { int num=0,bj=1; char x=getchar(); while(x<'0'||x>'9') { if(x=='-')bj=-1; x=getchar(); } while(x>='0'&&x<='9') { num=num*10+x-'0'; x=getchar(); } return num*bj; }
const int maxn=500005,maxc=26;
int n;
struct SuffixAutomaton { int cnt,root,last; int next[maxn*2],Max[maxn*2],end_pos[maxn*2],Bucket[maxn*2],top[maxn*2]; int child[maxn*2][maxc]; SuffixAutomaton() { cnt=0; root=last=newnode(0); } int newnode(int val) { cnt++; next[cnt]=0; Max[cnt]=val; memset(child[cnt],0,sizeof(child[cnt])); return cnt; } void insert(int data) { int p=last,u=newnode(Max[last]+1); last=u; end_pos[u]=1; for(; p&&!child[p][data]; p=next[p])child[p][data]=u; if(!p)next[u]=root; else { int old=child[p][data]; if(Max[old]==Max[p]+1)next[u]=old; else { int New=newnode(Max[p]+1); copy(child[old],child[old]+maxc,child[New]); next[New]=next[old]; next[u]=next[old]=New; for(; child[p][data]==old; p=next[p])child[p][data]=New; } } } void build(string s) { for(auto x:s)insert(x-'a'); } void topsort() { for(int i=1; i<=cnt; i++)Bucket[Max[i]]++; for(int i=1; i<=cnt; i++)Bucket[i]+=Bucket[i-1]; for(int i=1; i<=cnt; i++)top[Bucket[Max[i]]--]=i; } LL dp() { LL ans=(LL)(n-1)*n*(n+1)/2; for(int i=cnt; i>=1; i--) { int now=top[i],fa=next[now]; ans-=2ll*end_pos[now]*end_pos[fa]*Max[fa]; if(fa)end_pos[fa]+=end_pos[now]; } return ans; } } sam;
char s[maxn];
int main() { scanf("%s",s); n=strlen(s); reverse(s,s+n); sam.build(s); sam.topsort(); printf("%lld\n",sam.dp()); return 0; }
|