#ifndef _DATABASE_HPP #define _DATABASE_HPP #include #include #include #include #include #include #include #include #include #include namespace bserv { using raw_db_connection_type = pqxx::connection; using raw_db_transaction_type = pqxx::work; class db_field { private: pqxx::field field_; public: db_field(const pqxx::field& field) : field_{field} {} const char* c_str() const { return field_.c_str(); } template Type as() const { return field_.as(); } }; class db_row { private: pqxx::row row_; public: db_row(const pqxx::row& row) : row_{row} {} std::size_t size() const { return row_.size(); } db_field operator[](std::size_t idx) const { return row_[idx]; } }; class db_result { private: pqxx::result result_; public: class const_iterator { private: pqxx::result::const_iterator iterator_; public: const_iterator( const pqxx::result::const_iterator& iterator ) : iterator_{iterator} {} const_iterator& operator++() { ++iterator_; return *this; } bool operator==(const const_iterator& rhs) const { return iterator_ == rhs.iterator_; } bool operator!=(const const_iterator& rhs) const { return iterator_ != rhs.iterator_; } db_row operator*() const { return *iterator_; } }; db_result() = default; db_result(const pqxx::result& result) : result_{result} {} const_iterator begin() const { return result_.begin(); } const_iterator end() const { return result_.end(); } std::string query() const { return result_.query(); } }; class db_connection_manager; class db_connection { private: db_connection_manager& mgr_; std::shared_ptr conn_; public: db_connection( db_connection_manager& mgr, std::shared_ptr conn) : mgr_{mgr}, conn_{conn} {} // non-copiable, non-assignable db_connection(const db_connection&) = delete; db_connection& operator=(const db_connection&) = delete; // during the destruction, it should put itself back to the // manager's queue ~db_connection(); raw_db_connection_type& get() { return *conn_; } }; // provides the database connection pool functionality class db_connection_manager { private: std::queue> queue_; // this lock is for manipulating the `queue_` mutable std::mutex queue_lock_; // since C++ 17 doesn't provide the semaphore functionality, // mutex is used to mimic it. (boost provides it) // if there are no available connections, trying to lock on // it will cause blocking. mutable std::mutex counter_lock_; friend db_connection; public: db_connection_manager(const std::string& conn_str, int n) { for (int i = 0; i < n; ++i) queue_.emplace( std::make_shared(conn_str)); } // if there are no available database connections, this function // blocks until there is any; // otherwise, this function returns a pointer to `db_connection`. std::shared_ptr get_or_block() { // `counter_lock_` must be acquired first. // exchanging this statement with the next will cause dead-lock, // because if the request is blocked by `counter_lock_`, // the destructor of `db_connection` will not be able to put // itself back due to the `queue_lock_` has already been acquired // by this request! counter_lock_.lock(); // `queue_lock_` is acquired so that only one thread will // modify the `queue_` std::lock_guard lg{queue_lock_}; std::shared_ptr conn = queue_.front(); queue_.pop(); // if there are no connections in the `queue_`, // `counter_lock_` remains to be locked // so that the following requests will be blocked if (queue_.size() != 0) counter_lock_.unlock(); return std::make_shared(*this, conn); } }; inline db_connection::~db_connection() { std::lock_guard lg{mgr_.queue_lock_}; mgr_.queue_.emplace(conn_); // if this is the first available connection back to the queue, // `counter_lock_` is unlocked so that the blocked requests will // be notified if (mgr_.queue_.size() == 1) mgr_.counter_lock_.unlock(); } // ************************************************************************** class db_parameter { public: virtual ~db_parameter() = default; virtual std::string get_value(raw_db_transaction_type&) = 0; }; class db_name : public db_parameter { private: std::string value_; public: db_name(const std::string& value) : value_{value} {} std::string get_value(raw_db_transaction_type& tx) { return tx.quote_name(value_); } }; template class db_value : public db_parameter { private: Type value_; public: db_value(const Type& value) : value_{value} {} std::string get_value(raw_db_transaction_type&) { return std::to_string(value_); } }; template <> class db_value : public db_parameter { private: std::string value_; public: db_value(const std::string& value) : value_{value} {} std::string get_value(raw_db_transaction_type& tx) { return tx.quote(value_); } }; template <> class db_value : public db_parameter { private: bool value_; public: db_value(const bool& value) : value_{value} {} std::string get_value(raw_db_transaction_type&) { return value_ ? "true" : "false"; } }; namespace db_internal { template std::shared_ptr convert_parameter( const Param& param) { return std::make_shared>(param); } template std::shared_ptr convert_parameter( const db_value& param) { return std::make_shared>(param); } inline std::shared_ptr convert_parameter( const char* param) { return std::make_shared>(param); } inline std::shared_ptr convert_parameter( const db_name& param) { return std::make_shared(param); } template std::vector convert_parameters( raw_db_transaction_type& tx, std::shared_ptr... params) { return {params->get_value(tx)...}; } // ************************************* class db_field_holder { protected: std::string name_; public: db_field_holder(const std::string& name) : name_{name} {} virtual ~db_field_holder() = default; virtual void add( const db_row& row, std::size_t field_idx, boost::json::object& obj) = 0; }; template class db_field : public db_field_holder { public: using db_field_holder::db_field_holder; void add( const db_row& row, std::size_t field_idx, boost::json::object& obj) { obj[name_] = row[field_idx].as(); } }; template <> class db_field : public db_field_holder { public: using db_field_holder::db_field_holder; void add( const db_row& row, std::size_t field_idx, boost::json::object& obj) { obj[name_] = row[field_idx].c_str(); } }; } // db_internal template std::shared_ptr make_db_field( const std::string& name) { return std::make_shared>(name); } class invalid_operation_exception : public std::exception { private: std::string msg_; public: invalid_operation_exception(const std::string& msg) : msg_{msg} {} const char* what() const noexcept { return msg_.c_str(); } }; class db_relation_to_object { private: std::vector> fields_; public: db_relation_to_object( const std::initializer_list< std::shared_ptr>& fields) : fields_{fields} {} boost::json::object convert_row(const db_row& row) { boost::json::object obj; for (std::size_t i = 0; i < fields_.size(); ++i) fields_[i]->add(row, i, obj); return obj; } std::vector convert_to_vector( const db_result& result) { std::vector results; for (const auto& row : result) results.emplace_back(convert_row(row)); return results; } std::optional convert_to_optional( const db_result& result) { // result.size() == 0 if (result.begin() == result.end()) return std::nullopt; auto iterator = result.begin(); auto first = iterator; // result.size() == 1 if (++iterator == result.end()) return convert_row(*first); // result.size() > 1 throw invalid_operation_exception{ "too many objects to convert"}; } }; class db_transaction { private: raw_db_transaction_type tx_; public: db_transaction( std::shared_ptr connection_ptr ) : tx_{connection_ptr->get()} {} // non-copiable, non-assignable db_transaction(const db_transaction&) = delete; db_transaction& operator=(const db_transaction&) = delete; // Usage: // exec("select * from ? where ? = ? and first_name = 'Name??'", // db_name("auth_user"), db_name("is_active"), db_value(true)); // -> SQL: select * from "auth_user" where "is_active" = true and first_name = 'Name?' // ====================================================================================== // exec("select * from ? where ? = ? and first_name = ?", // db_name("auth_user"), db_name("is_active"), false, "Name??"); // -> SQL: select * from "auth_user" where "is_active" = false and first_name = 'Name??' // ====================================================================================== // Note: "?" is the placeholder for parameters, and "??" will be converted to "?" in SQL. // But, "??" in the parameters remains. template db_result exec(const std::string& s, const Params&... params) { std::vector param_vec = db_internal::convert_parameters( tx_, db_internal::convert_parameter(params)...); std::size_t idx = 0; std::string query; for (std::size_t i = 0; i < s.length(); ++i) { if (s[i] == '?') { if (i + 1 < s.length() && s[i + 1] == '?') { query += s[++i]; } else { if (idx < param_vec.size()) { query += param_vec[idx++]; } else throw std::out_of_range{"too few parameters"}; } } else query += s[i]; } if (idx != param_vec.size()) throw invalid_operation_exception{"too many parameters"}; return tx_.exec(query); } void commit() { tx_.commit(); } void abort() { tx_.abort(); } }; // TODO: add support for time conversions between postgresql and c++, use timestamp? // what about time zone? } // bserv #endif // _DATABASE_HPP