diff --git a/include.hh b/include.hh index af4271c..cd29fd8 100644 --- a/include.hh +++ b/include.hh @@ -381,17 +381,42 @@ template struct MLL { void operator/=(const MLL& rhs) { val = (*this / rhs).val; } void operator%=(const MLL& rhs) { val = (*this % rhs).val; } }; +struct MLLd { + ll val, mdl; + MLLd(ll mdl, ll v = 0) : mdl(mdl), val(mod(v, mdl)) {} + MLLd(const MLLd& other) : val(other.val) {} + friend MLLd operator+(const MLLd& lhs, const MLLd& rhs) { return mod(lhs.val + rhs.val, mdl); } + friend MLLd operator-(const MLLd& lhs, const MLLd& rhs) { return mod(lhs.val - rhs.val, mdl); } + friend MLLd operator*(const MLLd& lhs, const MLLd& rhs) { return mod(lhs.val * rhs.val, mdl); } + friend MLLd operator/(const MLLd& lhs, const MLLd& rhs) { return mod(lhs.val * mod(inverse(rhs.val, mdl), mdl), mdl); } + friend MLLd operator%(const MLLd& lhs, const MLLd& rhs) { return mod(lhs.val - (lhs / rhs).val, mdl); } + friend bool operator==(const MLLd& lhs, const MLLd& rhs) { return lhs.val == rhs.val; } + friend bool operator!=(const MLLd& lhs, const MLLd& rhs) { return lhs.val != rhs.val; } + void operator+=(const MLLd& rhs) { val = (*this + rhs).val; } + void operator-=(const MLLd& rhs) { val = (*this - rhs).val; } + void operator*=(const MLLd& rhs) { val = (*this * rhs).val; } + void operator/=(const MLLd& rhs) { val = (*this / rhs).val; } + void operator%=(const MLLd& rhs) { val = (*this % rhs).val; } +}; template ostream& operator<<(ostream& out, const MLL& num) { return out << num.val; } +ostream& operator<<(ostream& out, const MLLd& num) { + return out << num.val; +} + template istream& operator>>(istream& in, MLL& num) { return in >> num.val; } +istream& operator>>(istream& in, MLLd& num) { + return in >> num.val; +} + // miscancellous template void sort_by_key(RandomIt first, RandomIt last, Func extractor) { std::sort(first, last, [&] (auto&& a, auto&& b) { return std::less<>()(extractor(a), extractor(b)); });