因为 NOI 被虐傻了,蒟蒻的 YJQ 准备来学习一下字符串,于是它碰到了这样一道题:
给你一个长为 nnn 的字符串,求不同的子串的个数。
我们定义两个子串不同,当且仅当有这两个子串长度不一样或者长度一样且有任意一位不一样。
子串的定义:原字符串中连续的一段字符组成的字符串。
第一行一个整数 nnn。
接下来一行 nnn 个字符表示给出的字符串。
一行一个整数,表示不一样的子串个数。
5
aabaa
11
3
aba
5
请使用64位整数来进行输出。
对于 30%30\%30% 的数据,保证 n≤1000n\le 1000n≤1000。
对于 100%100\%100% 的数据,保证 1≤n≤1051 \leq n \le 10^51≤n≤105,字符串中只有小写英文字母。
1、子串就是后缀的前缀,所以可以枚举每个后缀,计算前缀总数,再减掉重复。
“前缀总数”其实就是子串个数,为 n * (n + 1) / 2
2、 如果按后缀排序的顺序枚举后缀,每次新增的子串就是除了与上一个后缀的 LCP 剩下的前缀,
只有这些前缀是新增的,因为 LCP 部分在枚举上一个前缀时计算过了,
所以答案为 n * (n + 1) / 2 - (height[1] + height[2] + … + height[n] )
#include
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;char s[N];
int n, m; //n是后缀个数, m是桶的个数
int x[N]; //桶数组
int y[N]; //辅助数组
int c[N]; //计数数组
int sa[N]; //sa[k] 表示排名为k的数组后缀编号
int rk[N]; //rk[k] 表示后缀字符串k 的排名
int height[N]; // heght[k] = lcp(sa[i], sa[i - 1])void get_sa()
{int i, k;// 按第一个字母排序for(i = 1; i <= n; ++i) // 按第一个字母编桶号, 并累计c[(x[i] = s[i])]++;for(i = 1; i <= m; ++i) c[i] += c[i - 1];for(i = n; i; --i) //后缀i的排序是i 所在桶号的剩余累计值sa[c[x[i]]--] = i;for(k = 1; k <= n; k <<= 1) // logn 轮{// 按第二关键字排序memset(c, 0, sizeof c);for(i = 1; i <= n; ++i) y[i] = sa[i];for(i = 1; i <= n; ++i) c[x[y[i] + k]]++;for(i = 1; i <= m; ++i) c[i] += c[i - 1];for(i = n; i; i--) sa[c[x[y[i] + k]]--] = y[i];//按第一关键字排序memset(c, 0, sizeof c);for(i = 1; i <= n; ++i) y[i] = sa[i];for(i = 1; i <= n; ++i) c[x[y[i]]]++;for(i = 1; i <= m; ++i) c[i] += c[i - 1];for(i = n; i; --i) sa[c[x[y[i]]]--] = y[i];//把后缀放入桶数组for(i = 1; i <= n; ++i) y[i] = x[i];for(m = 0, i = 1; i <= n; ++i){if(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k])x[sa[i]] = m;elsex[sa[i]] = ++m; // 相邻后缀的关键字不相等则放入新桶}if(m == n) break;}
}// 定理 height[rk[i]] >= height[rk[i - 1]] - 1;
void get_height()
{for(int i = 1; i <= n; ++i)rk[sa[i]] = i;for(int i = 1, k = 0; i <= n; ++i) //枚举后缀i{if(rk[i] == 1) continue; //第一名height 为0if(k) k--; //上一个后缀的height 值减 1int j = sa[rk[i] - 1]; //找出后缀i的前邻后缀 jwhile(i + k <= n && j + k <= n && s[i + k] == s[j + k])k++;height[rk[i]] = k;}
}int main()
{scanf("%d", &n);scanf("%s", s + 1);m = 128;get_sa();get_height();ll ans = (ll)n * (n + 1) / 2;for(int i = 2; i <= n; ++i)ans -= height[i];printf("%lld\n", ans);return 0;
}