【YBT2023寒假Day6 C】子串染色(SAM)(线段树)(启发式合并)
阿里云国内75折 回扣 微信号:monov8 |
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6 |
子串染色
题目链接YBT2023寒假Day6 C
题目大意
对于一个 s 的子串 t把它在 s 中所有出现的位置包含的所有下标染黑黑色连续段的数目是子串 t 的价值。
然后给你一个 k 和一个串 s求有多少个 s 的子串价值恰好为 k。
思路
那首先你如果要找到所有的子串SAM 是一个很好的方法。
于是我们考虑在 fail 树上处理这个问题。
不妨先考虑假如我们一个一个枚举子串要怎么判断是否可行。
那我们考虑先有子串在原串中出现了一些次数那右边的坐标我们记作
x
1
,
x
2
,
.
.
.
,
x
m
x_1,x_2,...,x_m
x1,x2,...,xm假设排序了。
那如果这个子串的长度是
k
k
k那对于满足
x
i
−
x
i
−
1
⩽
k
x_i-x_{i-1}\leqslant k
xi−xi−1⩽k 的数量我们记作
s
s
s那这个子串的价值就是
m
−
s
m-s
m−s。
就是缝起来的地方就相当于少了一个
那在 SAM 上一个点代表了一类子串而且长度是一个区间
[
l
,
r
]
[l,r]
[l,r]于是考虑加速或者维护上面的过程。
会发现只跟
x
i
−
x
i
−
1
x_i-x_{i-1}
xi−xi−1 这种东西有关那我们考虑用一个线段树维护
x
i
−
x
i
−
1
x_{i}-x_{i-1}
xi−xi−1 每个值的出现次数。
那我们找到第
m
−
k
m-k
m−k 小的数
a
a
a 和
m
−
k
+
1
m-k+1
m−k+1 小的数
b
b
b那
[
a
,
b
−
1
]
[a,b-1]
[a,b−1] 就是约束。
那配上我们有的区间就是
[
l
,
r
]
∩
[
a
,
b
−
1
]
[l,r]\cap [a,b-1]
[l,r]∩[a,b−1]
那考虑维护 x i − x i − 1 x_i-x_{i-1} xi−xi−1那因为这个出现的集合 x i x_i xi 是 fail 树上的子树那我们可以用 dsu on tree也可以用启发式合并用 set 维护 x i x_i xi用权值线段树维护 x i − x i − 1 x_i-x_{i-1} xi−xi−1至于查询 k k k 大可以直接二分虽然也可以在线段树上二分就是了
代码
#include<set>
#include<cstdio>
#include<vector>
#include<cstring>
#define ll long long
using namespace std;
const int N = 1e5 + 100;
char SS[N];
int n, k, dy[N << 1];
ll ans;
struct SAM {
int tot, lst;
struct node {
int len, fa, son[26];
}d[N << 1];
void Init() {
tot = lst = 1;
}
void insert(int x, int id) {
int p = lst, np = ++tot; lst = np;
d[np].len = d[p].len + 1; dy[np] = id;
for (; p && !d[p].son[x]; p = d[p].fa) d[p].son[x] = np;
if (!p) d[np].fa = 1;
else {
int q = d[p].son[x];
if (d[q].len == d[p].len + 1) d[np].fa = q;
else {
int nq = ++tot; d[nq] = d[q];
d[nq].len = d[p].len + 1;
d[q].fa = d[np].fa = nq;
for (; p && d[p].son[x] == q; p = d[p].fa) d[p].son[x] = nq;
}
}
}
}S;
struct XD_tree {
int f[N << 6], ls[N << 6], rs[N << 6], tot;
void update(int &now, int l, int r, int pl, int x) {
if (!now) now = ++tot;
f[now] += x;
if (l == r) return ;
int mid = (l + r) >> 1;
if (pl <= mid) update(ls[now], l, mid, pl, x);
else update(rs[now], mid + 1, r, pl, x);
}
int query(int now, int l, int r, int L, int R) {
if (!now) return 0;
if (L > R) return 0;
if (L <= l && r <= R) return f[now];
int mid = (l + r) >> 1, re = 0;
if (L <= mid) re += query(ls[now], l, mid, L, R);
if (mid < R) re += query(rs[now], mid + 1, r, L, R);
return re;
}
}T;
vector <int> G[N << 1];
set <int> s[N << 1];
int id[N << 1], rt[N << 1];
void merge(int &x, int &y) {
if (s[x].size() < s[y].size()) swap(x, y);
for (set <int> ::iterator it = s[y].begin(); it != s[y].end(); it++) {
int now = *it;
set <int> ::iterator pl = s[x].lower_bound(now);
int r = (pl == s[x].end()) ? 0 : *pl;
int l = (pl == s[x].begin()) ? 0 : *(--pl);
if (l && r) T.update(rt[x], 1, n, r - l, -1);
if (l) T.update(rt[x], 1, n, now - l, 1);
if (r) T.update(rt[x], 1, n, r - now, 1);
s[x].insert(now);
}
}
void dfs(int now) {
if (dy[now]) s[id[now]].insert(S.d[now].len);
for (int i = 0; i < G[now].size(); i++) {
dfs(G[now][i]);
merge(id[now], id[G[now][i]]);
}
if (now == 1) return ;
int m = s[id[now]].size();
int L = S.d[S.d[now].fa].len + 1, R = S.d[now].len, rel = L;
while (L <= R) {
int mid = (L + R) >> 1;
if (T.query(rt[id[now]], 1, n, 1, mid) >= m - k) rel = mid, R = mid - 1;
else L = mid + 1;
}
if (T.query(rt[id[now]], 1, n, 1, rel) != m - k) return ;
L = S.d[S.d[now].fa].len + 1; R = S.d[now].len; int rer = L;
while (L <= R) {
int mid = (L + R) >> 1;
if (T.query(rt[id[now]], 1, n, 1, mid) <= m - k) rer = mid, L = mid + 1;
else R = mid - 1;
}
ans += rer - rel + 1;
}
int main() {
freopen("gnirts.in", "r", stdin);
freopen("gnirts.out", "w", stdout);
scanf("%s", SS + 1); n = strlen(SS + 1);
scanf("%d", &k);
S.Init();
for (int i = 1; i <= n; i++) S.insert(SS[i] - 'a', i);
for (int i = 2; i <= S.tot; i++) G[S.d[i].fa].push_back(i);
for (int i = 1; i <= S.tot; i++) id[i] = i;
dfs(1);
printf("%lld", ans);
return 0;
}