diff --git a/string/sa.cc b/string/sa.cc index ee477b0..9b79543 100644 --- a/string/sa.cc +++ b/string/sa.cc @@ -1,52 +1,98 @@ -#include "../include.hh" +struct SA { + int n; + std::vector sa, rk, lc; + vector> st; -constexpr int N = 1e6 + 10; - -char s[N]; -int n, sa[N], rk[N], oldrk[N << 1], id[N], key1[N], cnt[N], height[N]; - -bool cmp(int x, int y, int w) { - return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]; -} - -void calc_sa() { - n = strlen(s + 1); - int i, m = 127, p, w; - for (i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]]; - for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; - for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i; - - for (w = 1;; w <<= 1, m = p) { - for (p = 0, i = n; i > n - w; --i) id[++p] = i; - for (i = 1; i <= n; ++i) - if (sa[i] > w) id[++p] = sa[i] - w; - - memset(cnt, 0, sizeof(cnt)); - for (i = 1; i <= n; ++i) ++cnt[key1[i] = rk[id[i]]]; - - for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; - for (i = n; i >= 1; --i) sa[cnt[key1[i]]--] = id[i]; - memcpy(oldrk + 1, rk + 1, n * sizeof(int)); - for (p = 0, i = 1; i <= n; ++i) - rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p; - if (p == n) { - break; + SA(std::string s) { + n = s.size(); + sa.resize(n); + lc.resize(n - 1); + rk.resize(n); + std::iota(sa.begin(), sa.end(), 0); + sort_by_key(sa.begin(), sa.end(), expr(s[i], int i)); + rk[sa[0]] = 0; + for (int i = 1; i < n; i++) { + rk[sa[i]] = rk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]); } - } -} + int k = 1; + std::vector tmp, cnt(n); + tmp.reserve(n); + while (rk[sa[n - 1]] < n - 1) { + tmp.clear(); + for (int i = 0; i < k; i++) { + tmp.push_back(n - k + i); + } + for (auto i : sa) { + if (i >= k) { + tmp.push_back(i - k); + } + } + std::fill(cnt.begin(), cnt.end(), 0); + for (int i = 0; i < n; i++) { + cnt[rk[i]]++; + } + for (int i = 1; i < n; i++) { + cnt[i] += cnt[i - 1]; + } + for (int i = n - 1; i >= 0; i--) { + sa[--cnt[rk[tmp[i]]]] = tmp[i]; + } + std::swap(rk, tmp); + rk[sa[0]] = 0; + for (int i = 1; i < n; i++) { + rk[sa[i]] = rk[sa[i - 1]] + (tmp[sa[i - 1]] < tmp[sa[i]] || sa[i - 1] + k == n || tmp[sa[i - 1] + k] < tmp[sa[i] + k]); + } + k *= 2; + } + for (int i = 0, j = 0; i < n; i++) { + if (rk[i] == 0) { + j = 0; + } else { + for (j -= j > 0; i + j < n && sa[rk[i] - 1] + j < n && s[i + j] == s[sa[rk[i] - 1] + j]; ) { + j++; + } + lc[rk[i] - 1] = j; + } + } + int m = lc.size(); + int lgm = lg2(m); + st = vector(lgm + 1, vector(m)); + st[0] = lc; + for (int j = 0; j < lgm; j++) { + for (int i = 0; i + (2 << j) <= m; i++) { + st[j + 1][i] = std::min(st[j][i], st[j][i + (1 << j)]); + } + } + } -void calc_height() { - for (i = 1, k = 0; i <= n; ++i) { - if (rk[i] == 0) continue; - if (k) --k; - while (s[i + k] == s[sa[rk[i] - 1] + k]) ++k; - height[rk[i]] = k; - } -} + int rmq(int l, int r) { + int k = lg2(r - l); + return std::min(st[k][l], st[k][r - (1 << k)]); + } -int main() { - untie; - cin >> (s + 1); // array s starts from index 1 - calc_sa(); - for (int i = 1; i <= n; ++i) cout << sa[i] << " \n"[i == n]; -} + __attribute__((target("lzcnt"))) + int lcp(int i, int j) { + if (i == j || i == n || j == n) { + return std::min(n - i, n - j); + } + int a = rk[i]; + int b = rk[j]; + if (a > b) { + std::swap(a, b); + } + deb(a, b, rmq(a, b)); + return std::min({n - i, n - j, rmq(a, b)}); + } + + int lcs(int i, int j) { + if (i == j || i == 0 || j == 0) { + return std::min(i, j); + } + int a = rk[n + n - i]; + int b = rk[n + n - j]; + if (a > b) { + std::swap(a, b); + } + return std::min({i, j, rmq(a, b)}); + } +};