From a831bd62780d6d58963c16d34f1f283f560307c8 Mon Sep 17 00:00:00 2001 From: arielherself Date: Thu, 26 Dec 2024 02:52:22 +0800 Subject: [PATCH] fix MLL multiplication overflow --- template.cc | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/template.cc b/template.cc index 6ac2531..aeb39b7 100644 --- a/template.cc +++ b/template.cc @@ -48,6 +48,7 @@ constexpr int INF = 0x3f3f3f3f; constexpr ll INFLL = 0x3f3f3f3f3f3f3f3fLL; constexpr ll MDL = 1e9 + 7; constexpr ll PRIME = 998'244'353; +constexpr ll PRIMELL = 901017227882342239LL; constexpr ll MDL1 = 8784491; constexpr ll MDL2 = PRIME; constexpr int128 INT128_MAX = numeric_limits::max(); @@ -96,9 +97,9 @@ struct igt { /* conditions */ #define loop while (1) -#define if_or(var, val) if (!(var == val)) var = val; else -#define continue_or(var, val) __AS_PROCEDURE(if (var == val) continue; var = val;) -#define break_or(var, val) __AS_PROCEDURE(if (var == val) break; var = val;) +#define if_or(var, val) if (!((var) == (val))) (var) = (val); else +#define continue_or(var, val) __AS_PROCEDURE(if ((var) == (val)) continue; (var) = (val);) +#define break_or(var, val) __AS_PROCEDURE(if ((var) == (val)) break; (var) = (val);) /* hash */ struct safe_hash { @@ -241,14 +242,14 @@ std::ostream& operator<<(std::ostream& dest, const int128& value) { template void __read(T& x) { cin >> x; } template void __read(T& x, U&... args) { cin >> x; __read(args...); } #define read(t, ...) __AS_PROCEDURE(argument_type::type __VA_ARGS__; __read(__VA_ARGS__);) -#define readvec(t, a, n) __AS_PROCEDURE(vector::type> a(n); for (auto& x : a) cin >> x;) -#define readvec1(t, a, n) __AS_PROCEDURE(vector::type> a((n) + 1); copy_n(ii::type>(cin), (n), a.begin() + 1);) -#define putvec(a) __AS_PROCEDURE(copy(a.begin(), a.end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) -#define putvec1(a) __AS_PROCEDURE(copy(a.begin() + 1, a.end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) -#define putvec_eol(a) __AS_PROCEDURE(copy(a.begin(), a.end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) -#define putvec1_eol(a) __AS_PROCEDURE(copy(a.begin() + 1, a.end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) +#define readvec(t, a, n) __AS_PROCEDURE(vector::type> a(n); for (auto& x : (a)) cin >> x;) +#define readvec1(t, a, n) __AS_PROCEDURE(vector::type> a((n) + 1); copy_n(ii::type>(cin), (n), (a).begin() + 1);) +#define putvec(a) __AS_PROCEDURE(copy((a).begin(), (a).end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) +#define putvec1(a) __AS_PROCEDURE(copy((a).begin() + 1, (a).end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) +#define putvec_eol(a) __AS_PROCEDURE(copy((a).begin(), (a).end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) +#define putvec1_eol(a) __AS_PROCEDURE(copy((a).begin() + 1, (a).end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) #define debug(x) __AS_PROCEDURE(cerr << #x" = " << (x) << endl;) -#define debugvec(a) __AS_PROCEDURE(cerr << #a" = "; for (auto&& x : a) cerr << x << ' '; cerr << endl;) +#define debugvec(a) __AS_PROCEDURE(cerr << #a" = "; for (auto&& x : (a)) cerr << x << ' '; cerr << endl;) #define deb(...) debug(make_tuple(__VA_ARGS__)) /* pops */ @@ -295,7 +296,7 @@ ll qpow(ll b, ll p, ll mod) { #pragma GCC diagnostic ignored "-Wparentheses" // Accurately find `i` 'th root of `n` (taking the floor) inline ll root(ll n, ll i) { - ll l = 0, r = pow(LLONG_MAX, ld(1) / i); + ll l = 0, r = pow(LLONG_MAX, (long double)(1) / i); while (l < r) { ll mid = l + r + 1 >> 1; if (qpow(mid, i) <= n) { @@ -395,7 +396,7 @@ vector calc_z(string t) { // z function of t } return z; } -vector kmp(string s, string t) { // find all t in s +vector kmp(const string& s, const string& t) { // find all t in s string cur = t + '#' + s; int sz1 = s.size(), sz2 = t.size(); vector v; @@ -405,7 +406,7 @@ vector kmp(string s, string t) { // find all t in s } return v; } -int period(string s) { // find the length of shortest recurring period +int period(const string& s) { // find the length of shortest recurring period int n = s.length(); auto z = calc_z(s); for (int i = 1; i <= n / 2; ++i) { @@ -423,8 +424,8 @@ template struct MLL { MLL(const MLL& other) : val(other.val) {} friend MLL operator+(const MLL& lhs, const MLL& rhs) { return mod(lhs.val + rhs.val, mdl); } friend MLL operator-(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - rhs.val, mdl); } - friend MLL operator*(const MLL& lhs, const MLL& rhs) { return mod(lhs.val * rhs.val, mdl); } - friend MLL operator/(const MLL& lhs, const MLL& rhs) { return mod(lhs.val * mod(inverse(rhs.val, mdl), mdl), mdl); } + friend MLL operator*(const MLL& lhs, const MLL& rhs) { return mod(int128(lhs.val) * rhs.val, mdl); } + friend MLL operator/(const MLL& lhs, const MLL& rhs) { return lhs * mod(inverse(rhs.val, mdl), mdl); } friend MLL operator%(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - (lhs / rhs).val, mdl); } friend bool operator==(const MLL& lhs, const MLL& rhs) { return lhs.val == rhs.val; } friend bool operator!=(const MLL& lhs, const MLL& rhs) { return lhs.val != rhs.val; }