1
0
Fork 0

fix MLL multiplication overflow

This commit is contained in:
arielherself 2024-12-26 02:52:22 +08:00
parent 4a6712e156
commit a831bd6278
Signed by: arielherself
SSH Key Fingerprint: SHA256:AK3cyo9tFsp7Mox7K0sYphleC8hReXhnRKxwuDT5LBc
1 changed files with 16 additions and 15 deletions

View File

@ -48,6 +48,7 @@ constexpr int INF = 0x3f3f3f3f;
constexpr ll INFLL = 0x3f3f3f3f3f3f3f3fLL; constexpr ll INFLL = 0x3f3f3f3f3f3f3f3fLL;
constexpr ll MDL = 1e9 + 7; constexpr ll MDL = 1e9 + 7;
constexpr ll PRIME = 998'244'353; constexpr ll PRIME = 998'244'353;
constexpr ll PRIMELL = 901017227882342239LL;
constexpr ll MDL1 = 8784491; constexpr ll MDL1 = 8784491;
constexpr ll MDL2 = PRIME; constexpr ll MDL2 = PRIME;
constexpr int128 INT128_MAX = numeric_limits<int128>::max(); constexpr int128 INT128_MAX = numeric_limits<int128>::max();
@ -96,9 +97,9 @@ struct igt {
/* conditions */ /* conditions */
#define loop while (1) #define loop while (1)
#define if_or(var, val) if (!(var == val)) var = val; else #define if_or(var, val) if (!((var) == (val))) (var) = (val); else
#define continue_or(var, val) __AS_PROCEDURE(if (var == val) continue; var = val;) #define continue_or(var, val) __AS_PROCEDURE(if ((var) == (val)) continue; (var) = (val);)
#define break_or(var, val) __AS_PROCEDURE(if (var == val) break; var = val;) #define break_or(var, val) __AS_PROCEDURE(if ((var) == (val)) break; (var) = (val);)
/* hash */ /* hash */
struct safe_hash { struct safe_hash {
@ -241,14 +242,14 @@ std::ostream& operator<<(std::ostream& dest, const int128& value) {
template<typename T> void __read(T& x) { cin >> x; } template<typename T> void __read(T& x) { cin >> x; }
template<typename T, typename... U> void __read(T& x, U&... args) { cin >> x; __read(args...); } template<typename T, typename... U> void __read(T& x, U&... args) { cin >> x; __read(args...); }
#define read(t, ...) __AS_PROCEDURE(argument_type<void(t)>::type __VA_ARGS__; __read(__VA_ARGS__);) #define read(t, ...) __AS_PROCEDURE(argument_type<void(t)>::type __VA_ARGS__; __read(__VA_ARGS__);)
#define readvec(t, a, n) __AS_PROCEDURE(vector<argument_type<void(t)>::type> a(n); for (auto& x : a) cin >> x;) #define readvec(t, a, n) __AS_PROCEDURE(vector<argument_type<void(t)>::type> a(n); for (auto& x : (a)) cin >> x;)
#define readvec1(t, a, n) __AS_PROCEDURE(vector<argument_type<void(t)>::type> a((n) + 1); copy_n(ii<argument_type<void(t)>::type>(cin), (n), a.begin() + 1);) #define readvec1(t, a, n) __AS_PROCEDURE(vector<argument_type<void(t)>::type> a((n) + 1); copy_n(ii<argument_type<void(t)>::type>(cin), (n), (a).begin() + 1);)
#define putvec(a) __AS_PROCEDURE(copy(a.begin(), a.end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) #define putvec(a) __AS_PROCEDURE(copy((a).begin(), (a).end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;)
#define putvec1(a) __AS_PROCEDURE(copy(a.begin() + 1, a.end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;) #define putvec1(a) __AS_PROCEDURE(copy((a).begin() + 1, (a).end(), oi<__as_typeof(a)::value_type>(cout, " ")); cout << endl;)
#define putvec_eol(a) __AS_PROCEDURE(copy(a.begin(), a.end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) #define putvec_eol(a) __AS_PROCEDURE(copy((a).begin(), (a).end(), oi<__as_typeof(a)::value_type>(cout, "\n"));)
#define putvec1_eol(a) __AS_PROCEDURE(copy(a.begin() + 1, a.end(), oi<__as_typeof(a)::value_type>(cout, "\n"));) #define putvec1_eol(a) __AS_PROCEDURE(copy((a).begin() + 1, (a).end(), oi<__as_typeof(a)::value_type>(cout, "\n"));)
#define debug(x) __AS_PROCEDURE(cerr << #x" = " << (x) << endl;) #define debug(x) __AS_PROCEDURE(cerr << #x" = " << (x) << endl;)
#define debugvec(a) __AS_PROCEDURE(cerr << #a" = "; for (auto&& x : a) cerr << x << ' '; cerr << endl;) #define debugvec(a) __AS_PROCEDURE(cerr << #a" = "; for (auto&& x : (a)) cerr << x << ' '; cerr << endl;)
#define deb(...) debug(make_tuple(__VA_ARGS__)) #define deb(...) debug(make_tuple(__VA_ARGS__))
/* pops */ /* pops */
@ -295,7 +296,7 @@ ll qpow(ll b, ll p, ll mod) {
#pragma GCC diagnostic ignored "-Wparentheses" #pragma GCC diagnostic ignored "-Wparentheses"
// Accurately find `i` 'th root of `n` (taking the floor) // Accurately find `i` 'th root of `n` (taking the floor)
inline ll root(ll n, ll i) { inline ll root(ll n, ll i) {
ll l = 0, r = pow(LLONG_MAX, ld(1) / i); ll l = 0, r = pow(LLONG_MAX, (long double)(1) / i);
while (l < r) { while (l < r) {
ll mid = l + r + 1 >> 1; ll mid = l + r + 1 >> 1;
if (qpow<int128>(mid, i) <= n) { if (qpow<int128>(mid, i) <= n) {
@ -395,7 +396,7 @@ vector<int> calc_z(string t) { // z function of t
} }
return z; return z;
} }
vector<int> kmp(string s, string t) { // find all t in s vector<int> kmp(const string& s, const string& t) { // find all t in s
string cur = t + '#' + s; string cur = t + '#' + s;
int sz1 = s.size(), sz2 = t.size(); int sz1 = s.size(), sz2 = t.size();
vector<int> v; vector<int> v;
@ -405,7 +406,7 @@ vector<int> kmp(string s, string t) { // find all t in s
} }
return v; return v;
} }
int period(string s) { // find the length of shortest recurring period int period(const string& s) { // find the length of shortest recurring period
int n = s.length(); int n = s.length();
auto z = calc_z(s); auto z = calc_z(s);
for (int i = 1; i <= n / 2; ++i) { for (int i = 1; i <= n / 2; ++i) {
@ -423,8 +424,8 @@ template <ll mdl> struct MLL {
MLL(const MLL<mdl>& other) : val(other.val) {} MLL(const MLL<mdl>& other) : val(other.val) {}
friend MLL operator+(const MLL& lhs, const MLL& rhs) { return mod(lhs.val + rhs.val, mdl); } friend MLL operator+(const MLL& lhs, const MLL& rhs) { return mod(lhs.val + rhs.val, mdl); }
friend MLL operator-(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - rhs.val, mdl); } friend MLL operator-(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - rhs.val, mdl); }
friend MLL operator*(const MLL& lhs, const MLL& rhs) { return mod(lhs.val * rhs.val, mdl); } friend MLL operator*(const MLL& lhs, const MLL& rhs) { return mod(int128(lhs.val) * rhs.val, mdl); }
friend MLL operator/(const MLL& lhs, const MLL& rhs) { return mod(lhs.val * mod(inverse(rhs.val, mdl), mdl), mdl); } friend MLL operator/(const MLL& lhs, const MLL& rhs) { return lhs * mod(inverse(rhs.val, mdl), mdl); }
friend MLL operator%(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - (lhs / rhs).val, mdl); } friend MLL operator%(const MLL& lhs, const MLL& rhs) { return mod(lhs.val - (lhs / rhs).val, mdl); }
friend bool operator==(const MLL& lhs, const MLL& rhs) { return lhs.val == rhs.val; } friend bool operator==(const MLL& lhs, const MLL& rhs) { return lhs.val == rhs.val; }
friend bool operator!=(const MLL& lhs, const MLL& rhs) { return lhs.val != rhs.val; } friend bool operator!=(const MLL& lhs, const MLL& rhs) { return lhs.val != rhs.val; }