From 3006214897acae414dde117f3b6c342a71a6d3aa Mon Sep 17 00:00:00 2001 From: jie Date: Mon, 9 Aug 2021 15:07:09 +0800 Subject: [PATCH] add wrappers for db --- bserv/database.hpp | 176 ++++++++++++++++++++++++++++++--------------- handlers.hpp | 65 +++++++++-------- scripts/db_test.py | 101 ++++++++++++++++++++++++++ 3 files changed, 257 insertions(+), 85 deletions(-) create mode 100644 scripts/db_test.py diff --git a/bserv/database.hpp b/bserv/database.hpp index 38be01d..ca5d812 100644 --- a/bserv/database.hpp +++ b/bserv/database.hpp @@ -16,16 +16,61 @@ namespace bserv { +using raw_db_connection_type = pqxx::connection; +using raw_db_transaction_type = pqxx::work; + +class db_field { +private: + pqxx::field field_; +public: + db_field(const pqxx::field& field) : field_{field} {} + const char* c_str() const { return field_.c_str(); } + template + Type as() const { return field_.as(); } +}; + +class db_row { +private: + pqxx::row row_; +public: + db_row(const pqxx::row& row) : row_{row} {} + std::size_t size() const { return row_.size(); } + db_field operator[](std::size_t idx) const { return row_[idx]; } +}; + +class db_result { +private: + pqxx::result result_; +public: + class const_iterator { + private: + pqxx::result::const_iterator iterator_; + public: + const_iterator( + const pqxx::result::const_iterator& iterator + ) : iterator_{iterator} {} + const_iterator& operator++() { ++iterator_; return *this; } + bool operator==(const const_iterator& rhs) const { return iterator_ == rhs.iterator_; } + bool operator!=(const const_iterator& rhs) const { return iterator_ != rhs.iterator_; } + db_row operator*() const { return *iterator_; } + }; + db_result() = default; + db_result(const pqxx::result& result) : result_{result} {} + const_iterator begin() const { return result_.begin(); } + const_iterator end() const { return result_.end(); } + std::string query() const { return result_.query(); } +}; + class db_connection_manager; class db_connection { private: db_connection_manager& mgr_; - std::shared_ptr conn_; + std::shared_ptr conn_; public: db_connection( db_connection_manager& mgr, - std::shared_ptr conn) + std::shared_ptr conn) : mgr_{mgr}, conn_{conn} {} // non-copiable, non-assignable db_connection(const db_connection&) = delete; @@ -33,13 +78,13 @@ public: // during the destruction, it should put itself back to the // manager's queue ~db_connection(); - pqxx::connection& get() { return *conn_; } + raw_db_connection_type& get() { return *conn_; } }; // provides the database connection pool functionality class db_connection_manager { private: - std::queue> queue_; + std::queue> queue_; // this lock is for manipulating the `queue_` mutable std::mutex queue_lock_; // since C++ 17 doesn't provide the semaphore functionality, @@ -52,7 +97,7 @@ public: db_connection_manager(const std::string& conn_str, int n) { for (int i = 0; i < n; ++i) queue_.emplace( - std::make_shared(conn_str)); + std::make_shared(conn_str)); } // if there are no available database connections, this function // blocks until there is any; @@ -68,7 +113,7 @@ public: // `queue_lock_` is acquired so that only one thread will // modify the `queue_` std::lock_guard lg{queue_lock_}; - std::shared_ptr conn = queue_.front(); + std::shared_ptr conn = queue_.front(); queue_.pop(); // if there are no connections in the `queue_`, // `counter_lock_` remains to be locked @@ -93,7 +138,7 @@ inline db_connection::~db_connection() { class db_parameter { public: virtual ~db_parameter() = default; - virtual std::string get_value(pqxx::work&) = 0; + virtual std::string get_value(raw_db_transaction_type&) = 0; }; class db_name : public db_parameter { @@ -102,8 +147,8 @@ private: public: db_name(const std::string& value) : value_{value} {} - std::string get_value(pqxx::work& w) { - return w.quote_name(value_); + std::string get_value(raw_db_transaction_type& tx) { + return tx.quote_name(value_); } }; @@ -114,7 +159,7 @@ private: public: db_value(const Type& value) : value_{value} {} - std::string get_value(pqxx::work&) { + std::string get_value(raw_db_transaction_type&) { return std::to_string(value_); } }; @@ -126,8 +171,8 @@ private: public: db_value(const std::string& value) : value_{value} {} - std::string get_value(pqxx::work& w) { - return w.quote(value_); + std::string get_value(raw_db_transaction_type& tx) { + return tx.quote(value_); } }; @@ -138,7 +183,7 @@ private: public: db_value(const bool& value) : value_{value} {} - std::string get_value(pqxx::work&) { + std::string get_value(raw_db_transaction_type&) { return value_ ? "true" : "false"; } }; @@ -169,8 +214,8 @@ inline std::shared_ptr convert_parameter( template std::vector convert_parameters( - pqxx::work& w, std::shared_ptr... params) { - return {params->get_value(w)...}; + raw_db_transaction_type& tx, std::shared_ptr... params) { + return {params->get_value(tx)...}; } // ************************************* @@ -183,7 +228,7 @@ public: : name_{name} {} virtual ~db_field_holder() = default; virtual void add( - const pqxx::row& row, size_t field_idx, + const db_row& row, std::size_t field_idx, boost::json::object& obj) = 0; }; @@ -192,7 +237,7 @@ class db_field : public db_field_holder { public: using db_field_holder::db_field_holder; void add( - const pqxx::row& row, size_t field_idx, + const db_row& row, std::size_t field_idx, boost::json::object& obj) { obj[name_] = row[field_idx].as(); } @@ -203,7 +248,7 @@ class db_field : public db_field_holder { public: using db_field_holder::db_field_holder; void add( - const pqxx::row& row, size_t field_idx, + const db_row& row, std::size_t field_idx, boost::json::object& obj) { obj[name_] = row[field_idx].c_str(); } @@ -234,63 +279,80 @@ public: const std::initializer_list< std::shared_ptr>& fields) : fields_{fields} {} - boost::json::object convert_row(const pqxx::row& row) { + boost::json::object convert_row(const db_row& row) { boost::json::object obj; - for (size_t i = 0; i < fields_.size(); ++i) + for (std::size_t i = 0; i < fields_.size(); ++i) fields_[i]->add(row, i, obj); return obj; } std::vector convert_to_vector( - const pqxx::result& result) { + const db_result& result) { std::vector results; for (const auto& row : result) results.emplace_back(convert_row(row)); return results; } std::optional convert_to_optional( - const pqxx::result& result) { - if (result.size() == 0) return std::nullopt; - if (result.size() == 1) return convert_row(result[0]); + const db_result& result) { + // result.size() == 0 + if (result.begin() == result.end()) return std::nullopt; + auto iterator = result.begin(); + auto first = iterator; + // result.size() == 1 + if (++iterator == result.end()) + return convert_row(*first); // result.size() > 1 throw invalid_operation_exception{ "too many objects to convert"}; } }; -// Usage: -// db_exec(tx, "select * from ? where ? = ? and first_name = 'Name??'", -// db_name("auth_user"), db_name("is_active"), db_value(true)); -// -> SQL: select * from "auth_user" where "is_active" = true and first_name = 'Name?' -// ====================================================================================== -// db_exec(tx, "select * from ? where ? = ? and first_name = ?", -// db_name("auth_user"), db_name("is_active"), false, "Name??"); -// -> SQL: select * from "auth_user" where "is_active" = false and first_name = 'Name??' -// ====================================================================================== -// Note: "?" is the placeholder for parameters, and "??" will be converted to "?" in SQL. -// But, "??" in the parameters remains. -template -pqxx::result db_exec(pqxx::work& w, - const std::string& s, const Params&... params) { - std::vector param_vec = - db_internal::convert_parameters( - w, db_internal::convert_parameter(params)...); - size_t idx = 0; - std::string query; - for (size_t i = 0; i < s.length(); ++i) { - if (s[i] == '?') { - if (i + 1 < s.length() && s[i + 1] == '?') { - query += s[++i]; - } else { - if (idx < param_vec.size()) { - query += param_vec[idx++]; - } else throw std::out_of_range{"too few parameters"}; - } - } else query += s[i]; +class db_transaction { +private: + raw_db_transaction_type tx_; +public: + db_transaction( + std::shared_ptr connection_ptr + ) : tx_{connection_ptr->get()} {} + // non-copiable, non-assignable + db_transaction(const db_transaction&) = delete; + db_transaction& operator=(const db_transaction&) = delete; + // Usage: + // exec("select * from ? where ? = ? and first_name = 'Name??'", + // db_name("auth_user"), db_name("is_active"), db_value(true)); + // -> SQL: select * from "auth_user" where "is_active" = true and first_name = 'Name?' + // ====================================================================================== + // exec("select * from ? where ? = ? and first_name = ?", + // db_name("auth_user"), db_name("is_active"), false, "Name??"); + // -> SQL: select * from "auth_user" where "is_active" = false and first_name = 'Name??' + // ====================================================================================== + // Note: "?" is the placeholder for parameters, and "??" will be converted to "?" in SQL. + // But, "??" in the parameters remains. + template + db_result exec(const std::string& s, const Params&... params) { + std::vector param_vec = + db_internal::convert_parameters( + tx_, db_internal::convert_parameter(params)...); + std::size_t idx = 0; + std::string query; + for (std::size_t i = 0; i < s.length(); ++i) { + if (s[i] == '?') { + if (i + 1 < s.length() && s[i + 1] == '?') { + query += s[++i]; + } else { + if (idx < param_vec.size()) { + query += param_vec[idx++]; + } else throw std::out_of_range{"too few parameters"}; + } + } else query += s[i]; + } + if (idx != param_vec.size()) + throw invalid_operation_exception{"too many parameters"}; + return tx_.exec(query); } - if (idx != param_vec.size()) - throw invalid_operation_exception{"too many parameters"}; - return w.exec(query); -} + void commit() { tx_.commit(); } + void abort() { tx_.abort(); } +}; // TODO: add support for time conversions between postgresql and c++, use timestamp? diff --git a/handlers.hpp b/handlers.hpp index 380fe69..5cd5d53 100644 --- a/handlers.hpp +++ b/handlers.hpp @@ -8,8 +8,6 @@ #include #include -#include - #include "bserv/common.hpp" // register an orm mapping (to convert the db query results into @@ -35,17 +33,17 @@ bserv::db_relation_to_object orm_user{ }; std::optional get_user( - pqxx::work& tx, - const std::string& username) { - pqxx::result r = bserv::db_exec(tx, + bserv::db_transaction& tx, + const std::string& username) { + bserv::db_result r = tx.exec( "select * from auth_user where username = ?", username); lginfo << r.query(); // this is how you log info return orm_user.convert_to_optional(r); } std::string get_or_empty( - boost::json::object& obj, - const std::string& key) { + boost::json::object& obj, + const std::string& key) { return obj.count(key) ? obj[key].as_string().c_str() : ""; } @@ -53,15 +51,25 @@ std::string get_or_empty( // the return type should be `std::nullopt_t`, // and the return value should be `std::nullopt`. std::nullopt_t hello( - bserv::response_type& response, - std::shared_ptr session_ptr) { + bserv::response_type& response, + std::shared_ptr session_ptr) { bserv::session_type& session = *session_ptr; boost::json::object obj; if (session.count("user")) { + // NOTE: modifications to sessions must be performed + // BEFORE referencing objects in them. this is because + // modifications might invalidate referenced objects. + // in this example, "count" might be added to `session`, + // which should be performed first. + // then `user` can be referenced safely. + if (!session.count("count")) { + session["count"] = 0; + } auto& user = session["user"].as_object(); + session["count"] = session["count"].as_int64() + 1; obj = { - {"msg", std::string{"welcome, "} - + user["username"].as_string().c_str() + "!"} + {"welcome", user["username"]}, + {"count", session["count"]} }; } else { obj = {{"msg", "hello, world!"}}; @@ -76,11 +84,11 @@ std::nullopt_t hello( // if you return a json object, the serialization // is performed automatically. boost::json::object user_register( - bserv::request_type& request, - // the json object is obtained from the request body, - // as well as the url parameters - boost::json::object&& params, - std::shared_ptr conn) { + bserv::request_type& request, + // the json object is obtained from the request body, + // as well as the url parameters + boost::json::object&& params, + std::shared_ptr conn) { if (request.method() != boost::beast::http::verb::post) { throw bserv::url_not_found_exception{}; } @@ -97,7 +105,7 @@ boost::json::object user_register( }; } auto username = params["username"].as_string(); - pqxx::work tx{conn->get()}; + bserv::db_transaction tx{conn}; auto opt_user = get_user(tx, username.c_str()); if (opt_user.has_value()) { return { @@ -106,7 +114,7 @@ boost::json::object user_register( }; } auto password = params["password"].as_string(); - pqxx::result r = bserv::db_exec(tx, + bserv::db_result r = tx.exec( "insert into ? " "(?, password, is_superuser, " "first_name, last_name, email, is_active) values " @@ -127,10 +135,10 @@ boost::json::object user_register( } boost::json::object user_login( - bserv::request_type& request, - boost::json::object&& params, - std::shared_ptr conn, - std::shared_ptr session_ptr) { + bserv::request_type& request, + boost::json::object&& params, + std::shared_ptr conn, + std::shared_ptr session_ptr) { if (request.method() != boost::beast::http::verb::post) { throw bserv::url_not_found_exception{}; } @@ -147,7 +155,7 @@ boost::json::object user_login( }; } auto username = params["username"].as_string(); - pqxx::work tx{conn->get()}; + bserv::db_transaction tx{conn}; auto opt_user = get_user(tx, username.c_str()); if (!opt_user.has_value()) { return { @@ -180,9 +188,9 @@ boost::json::object user_login( } boost::json::object find_user( - std::shared_ptr conn, - const std::string& username) { - pqxx::work tx{conn->get()}; + std::shared_ptr conn, + const std::string& username) { + bserv::db_transaction tx{conn}; auto user = get_user(tx, username); if (!user.has_value()) { return { @@ -190,6 +198,7 @@ boost::json::object find_user( {"message", "requested user does not exist"} }; } + user.value().erase("id"); user.value().erase("password"); return { {"success", true}, @@ -198,7 +207,7 @@ boost::json::object find_user( } boost::json::object user_logout( - std::shared_ptr session_ptr) { + std::shared_ptr session_ptr) { bserv::session_type& session = *session_ptr; if (session.count("user")) { session.erase("user"); @@ -235,7 +244,7 @@ boost::json::object send_request( } boost::json::object echo( - boost::json::object&& params) { + boost::json::object&& params) { return {{"echo", params}}; } diff --git a/scripts/db_test.py b/scripts/db_test.py new file mode 100644 index 0000000..728ebe9 --- /dev/null +++ b/scripts/db_test.py @@ -0,0 +1,101 @@ +import uuid +import string +import secrets +import random + +import requests + +from multiprocessing import Process + +from pprint import pprint + +from time import time + + +char_string = string.ascii_letters + string.digits + +def get_password(n): + return ''.join(secrets.choice(char_string) for _ in range(n)) + +def get_string(n): + return ''.join(random.choice(char_string) for _ in range(n)) + +def create_user(): + return { + "username": str(uuid.uuid4()), + "password": get_password(16), + "first_name": get_string(5), + "last_name": get_string(5), + "email": get_string(5) + "@" + get_string(5) + ".com" + } + +# pprint(create_user()) +# exit() + +def session_test(): + session = requests.session() + user = create_user() + res = session.post("http://localhost:8080").json() + if res != {'msg': 'hello, world!'}: + print('test failed') + # print(res) + res = session.post("http://localhost:8080/register", json=user).json() + if res != {'success': True, 'message': 'user registered'}: + print('test failed') + # print(res) + res = session.post("http://localhost:8080/login", json={ + "username": user["username"], + "password": user["password"] + }).json() + if res != {'success': True, 'message': 'login successfully'}: + print('test failed') + # print(res) + n = random.randint(1, 5) + for i in range(1, n + 1): + res = session.post("http://localhost:8080").json() + if res != {'welcome': user["username"], 'count': i}: + print('test failed') + # print(res) + res = session.post("http://localhost:8080/find/" + user["username"]).json() + if res != { + 'success': True, + 'user': { + 'username': user["username"], + 'is_active': True, + 'is_superuser': False, + 'first_name': user["first_name"], + 'last_name': user["last_name"], + 'email': user["email"] + }}: + print('test failed') + # print(res) + res = session.post("http://localhost:8080/logout").json() + if res != {'success': True, 'message': 'logout successfully'}: + print('test failed') + # print(res) + res = session.post("http://localhost:8080").json() + if res != {'msg': 'hello, world!'}: + print('test failed') + # print(res) + +# session_test() +# exit() + +P = 1000 # number of concurrent processes + +processes = [Process(target=session_test) for i in range(P)] + +print('starting') + +start = time() + +for p in processes: + p.start() + +for p in processes: + p.join() + +end = time() + +print('test ended') +print('elapsed: ', end - start)