From a6f745a787bb247a1d385e36617b7c6bed43542d Mon Sep 17 00:00:00 2001 From: jie Date: Fri, 5 Mar 2021 15:39:47 +0800 Subject: [PATCH] initial commit --- .gitignore | 5 + CMakeLists.txt | 24 ++++ README.md | 71 ++++++++- bserv.cpp | 374 ++++++++++++++++++++++++++++++++++++++++++++++++ build/README.md | 6 + common.hpp | 10 ++ config.hpp | 115 +++++++++++++++ database.hpp | 303 +++++++++++++++++++++++++++++++++++++++ db.sql | 10 ++ handlers.hpp | 212 +++++++++++++++++++++++++++ logging.hpp | 42 ++++++ router.hpp | 354 +++++++++++++++++++++++++++++++++++++++++++++ routing.hpp | 32 +++++ session.hpp | 110 ++++++++++++++ utils.hpp | 224 +++++++++++++++++++++++++++++ 15 files changed, 1891 insertions(+), 1 deletion(-) create mode 100644 CMakeLists.txt create mode 100644 bserv.cpp create mode 100644 build/README.md create mode 100644 common.hpp create mode 100644 config.hpp create mode 100644 database.hpp create mode 100644 db.sql create mode 100644 handlers.hpp create mode 100644 logging.hpp create mode 100644 router.hpp create mode 100644 routing.hpp create mode 100644 session.hpp create mode 100644 utils.hpp diff --git a/.gitignore b/.gitignore index 259148f..324b3b0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +test/ +build/* +!build/README.md +.* + # Prerequisites *.d diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..03d9e5a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.10) + +project(bserv) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_CXX_FLAGS "-Wall -Wextra") +set(CMAKE_CXX_FLAGS_DEBUG "-g") +set(CMAKE_CXX_FLAGS_RELEASE "-O3") + +add_executable(bserv bserv.cpp) +target_link_libraries(bserv + pthread + boost_thread + boost_log + boost_log_setup + pqxx + pq + cryptopp) diff --git a/README.md b/README.md index 3c50d01..0ab404b 100644 --- a/README.md +++ b/README.md @@ -1 +1,70 @@ -# bserv \ No newline at end of file +# bserv + +*A Boost Based High Performance C++ HTTP JSON Server.* + + +## Dependencies + +- [Boost 1.75.0](https://www.boost.org/) +- [PostgreSQL 13.2](https://www.postgresql.org/) +- [Libpqxx 7.3.1](https://github.com/jtv/libpqxx) +- [Crypto++ 8.4.0](https://cryptopp.com/) +- CMake + + +## Quick Start + +### Database + +You can import the sample database: + +- Create the database in `psql`: + ``` + create database bserv; + ``` + +- Create the table in the `shell` using a sample script: + ``` + psql bserv < db.sql + ``` + + +### Routing + +Configure routing in [routing.hpp](routing.hpp). + + +### Handlers + +Write the handlers in [handlers.hpp](handlers.hpp) + + +## Build + +Please refer to [this](build/README.md). + + +## Running + +Run in `shell`: +``` +./build/bserv +``` + + +## Performance + +This test is performed by Jmeter. + +The unit for throughput is Transaction per second. + + +|URL|bserv|Java Spring Boot| +|:-:|:-:|:-:| +|`/login`|139.55|| +|`/find/`|958.77|| + + +### Computer Hardware: +- Intel Core i9-9900K x 4 +- 16GB RAM diff --git a/bserv.cpp b/bserv.cpp new file mode 100644 index 0000000..9f2b38b --- /dev/null +++ b/bserv.cpp @@ -0,0 +1,374 @@ +/** + * bserv - Boost-based HTTP Server + * + * reference: + * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/http/server/async/http_server_async.cpp + * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/advanced/server/advanced_server.cpp + * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "config.hpp" +#include "logging.hpp" +#include "utils.hpp" +#include "routing.hpp" +#include "database.hpp" + +namespace bserv { + +namespace beast = boost::beast; +namespace http = beast::http; +namespace asio = boost::asio; +namespace json = boost::json; +using asio::ip::tcp; + +void fail(const beast::error_code& ec, const char* what) { + lgerror << what << ": " << ec.message() << std::endl; +} + +// this function produces an HTTP response for the given +// request. The type of the response object depends on the +// contents of the request, so the interface requires the +// caller to pass a generic lambda for receiving the response. +// NOTE: `send` should be called only once! +template +void handle_request( + http::request>&& req, + Send&& send) { + + const auto bad_request = [&req](beast::string_view why) { + http::response res{ + http::status::bad_request, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = std::string{why}; + res.prepare_payload(); + return res; + }; + + const auto not_found = [&req](beast::string_view target) { + http::response res{ + http::status::not_found, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = "The requested url '" + + std::string{target} + "' does not exist."; + res.prepare_payload(); + return res; + }; + + const auto server_error = [&req](beast::string_view what) { + http::response res{ + http::status::internal_server_error, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = "Internal server error: " + std::string{what}; + res.prepare_payload(); + return res; + }; + + boost::string_view target = req.target(); + auto pos = target.find('?'); + boost::string_view url; + if (pos == boost::string_view::npos) url = target; + else url = target.substr(0, pos); + + http::response res{ + http::status::ok, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "application/json"); + res.keep_alive(req.keep_alive()); + + std::optional val; + try { + val = routes(std::string{url}, req, res); + } catch (const url_not_found_exception& e) { + send(not_found(url)); + return; + } catch (const bad_request_exception& e) { + send(bad_request("Request body is not a valid JSON string.")); + return; + } catch (const std::exception& e) { + send(server_error(e.what())); + return; + } catch (...) { + send(server_error("Unknown exception.")); + return; + } + + if (val.has_value()) { + res.body() = json::serialize(val.value()); + res.prepare_payload(); + } + + send(std::move(res)); +} + +// std::string get_address(const tcp::socket& socket) { +// tcp::endpoint end_point = socket.remote_endpoint(); +// std::string addr = end_point.address().to_string() +// + ':' + std::to_string(end_point.port()); +// return addr; +// } + +// handles an HTTP server connection +class http_session + : public std::enable_shared_from_this { +private: + // the function object is used to send an HTTP message. + class send_lambda { + private: + http_session& self_; + public: + send_lambda(http_session& self) + : self_{self} {} + template + void operator()( + http::message&& msg) const { + // the lifetime of the message has to extend + // for the duration of the async operation so + // we use a shared_ptr to manage it. + auto sp = std::make_shared< + http::message>( + std::move(msg)); + // stores a type-erased version of the shared + // pointer in the class to keep it alive. + self_.res_ = sp; + // writes the response + http::async_write( + self_.stream_, *sp, + beast::bind_front_handler( + &http_session::on_write, + self_.shared_from_this(), + sp->need_eof())); + } + } lambda_; + beast::tcp_stream stream_; + beast::flat_buffer buffer_; + boost::optional< + http::request_parser> parser_; + std::shared_ptr res_; + // const std::string address_; + void do_read() { + // constructs a new parser for each message + parser_.emplace(); + // applies a reasonable limit to the allowed size + // of the body in bytes to prevent abuse. + parser_->body_limit(PAYLOAD_LIMIT); + // sets the timeout. + stream_.expires_after(std::chrono::seconds(30)); + // reads a request using the parser-oriented interface + http::async_read( + stream_, buffer_, *parser_, + beast::bind_front_handler( + &http_session::on_read, + shared_from_this())); + } + void on_read( + beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + // lgtrace << "received " << bytes_transferred << " byte(s) from: " << address_; + // this means they closed the connection + if (ec == http::error::end_of_stream) { + do_close(); + return; + } + if (ec) { + fail(ec, "http_session async_read"); + return; + } + // handles the request and sends the response + handle_request(parser_->release(), lambda_); + // at this point the parser can be reset + } + void on_write( + bool close, beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + // we're done with the response so delete it + res_.reset(); + if (ec) { + fail(ec, "http_session async_write"); + return; + } + // lgtrace << "sent " << bytes_transferred << " byte(s) to: " << address_; + if (close) { + // this means we should close the connection, usually because + // the response indicated the "Connection: close" semantic. + do_close(); + return; + } + // reads another request + do_read(); + } + void do_close() { + // sends a TCP shutdown + beast::error_code ec; + stream_.socket().shutdown(tcp::socket::shutdown_send, ec); + // at this point the connection is closed gracefully + // lgtrace << "socket connection closed: " << address_; + } +public: + http_session(tcp::socket&& socket) + : lambda_{*this}, stream_{std::move(socket)} /*, + address_{get_address(stream_.socket())} */ { + // lgtrace << "http session opened: " << address_; + } + ~http_session() { + // lgtrace << "http session closed: " << address_; + } + void run() { + asio::dispatch( + stream_.get_executor(), + beast::bind_front_handler( + &http_session::do_read, + shared_from_this())); + } +}; + +// accepts incoming connections and launches the sessions +class listener + : public std::enable_shared_from_this { +private: + asio::io_context& ioc_; + tcp::acceptor acceptor_; + void do_accept() { + acceptor_.async_accept( + asio::make_strand(ioc_), + beast::bind_front_handler( + &listener::on_accept, + shared_from_this())); + } + void on_accept(beast::error_code ec, tcp::socket socket) { + if (ec) { + fail(ec, "listener::acceptor async_accept"); + } else { + // lgtrace << "listener accepts: " << get_address(socket); + std::make_shared( + std::move(socket))->run(); + } + do_accept(); + } +public: + listener( + asio::io_context& ioc, + tcp::endpoint endpoint) + : ioc_{ioc}, + acceptor_{asio::make_strand(ioc)} { + beast::error_code ec; + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + fail(ec, "listener::acceptor open"); + return; + } + acceptor_.set_option( + asio::socket_base::reuse_address(true), ec); + if (ec) { + fail(ec, "listener::acceptor set_option"); + return; + } + acceptor_.bind(endpoint, ec); + if (ec) { + fail(ec, "listener::acceptor bind"); + return; + } + acceptor_.listen( + asio::socket_base::max_listen_connections, ec); + if (ec) { + fail(ec, "listener::acceptor listen"); + return; + } + } + void run() { + asio::dispatch( + acceptor_.get_executor(), + beast::bind_front_handler( + &listener::do_accept, + shared_from_this())); + } +}; + +void show_config() { + lginfo << NAME << " config:" + << "\nport: " << PORT + << "\nthreads: " << NUM_THREADS + << "\ndb-conn: " << NUM_DB_CONN + << "\npayload: " << PAYLOAD_LIMIT / 1024 / 1024 + << "\nrotation: " << LOG_ROTATION_SIZE / 1024 / 1024 + << "\nlog path: " << LOG_PATH + << "\nconn-str: " << DB_CONN_STR << std::endl; +} + +} // bserv + +int main(int argc, char* argv[]) { + using namespace bserv; + if (parse_arguments(argc, argv)) + return EXIT_FAILURE; + init_logging(); + show_config(); + + // some initializations must be done after parsing the arguments + // e.g. database connection + try { + db_conn_mgr = std::make_shared< + db_connection_manager>(DB_CONN_STR, NUM_DB_CONN); + } catch (const std::exception& e) { + lgfatal << "db connection initialization failed: " << e.what() << std::endl; + return EXIT_FAILURE; + } + session_mgr = std::make_shared(); + + // io_context for all I/O + asio::io_context ioc{NUM_THREADS}; + + // creates and launches a listening port + std::make_shared( + ioc, tcp::endpoint{tcp::v4(), PORT})->run(); + + // captures SIGINT and SIGTERM to perform a clean shutdown + asio::signal_set signals{ioc, SIGINT, SIGTERM}; + signals.async_wait( + [&](const boost::system::error_code&, int) { + // stops the `io_context`. This will cause `run()` + // to return immediately, eventually destroying the + // `io_context` and all of the sockets in it. + ioc.stop(); + }); + + lginfo << NAME << " started"; + + // runs the I/O service on the requested number of threads + std::vector v; + v.reserve(NUM_THREADS - 1); + for (int i = 1; i < NUM_THREADS; ++i) + v.emplace_back([&]{ ioc.run(); }); + ioc.run(); + + // if we get here, it means we got a SIGINT or SIGTERM + lginfo << "exiting " << NAME; + + // blocks until all the threads exit + for (auto & t : v) t.join(); + return EXIT_SUCCESS; +} diff --git a/build/README.md b/build/README.md new file mode 100644 index 0000000..8d91ad8 --- /dev/null +++ b/build/README.md @@ -0,0 +1,6 @@ +# Build + +``` +cmake .. +cmake --build . +``` diff --git a/common.hpp b/common.hpp new file mode 100644 index 0000000..ff810c1 --- /dev/null +++ b/common.hpp @@ -0,0 +1,10 @@ +#ifndef _COMMON_HPP +#define _COMMON_HPP + +#include "database.hpp" +#include "session.hpp" +#include "router.hpp" +#include "utils.hpp" +#include "logging.hpp" + +#endif // _COMMON_HPP \ No newline at end of file diff --git a/config.hpp b/config.hpp new file mode 100644 index 0000000..924a8b2 --- /dev/null +++ b/config.hpp @@ -0,0 +1,115 @@ +#ifndef _CONFIG_HPP +#define _CONFIG_HPP + +#include +#include +#include +#include + +namespace bserv { + +const char* NAME = "bserv"; + +unsigned short PORT = 8080; +int NUM_THREADS = 4; +int NUM_DB_CONN = 10; + +std::size_t PAYLOAD_LIMIT = 1 * 1024 * 1024; + +std::size_t LOG_ROTATION_SIZE = 4 * 1024 * 1024; +std::string LOG_PATH = "./log/"; +std::string DB_CONN_STR = "dbname=bserv"; + +void show_usage() { + std::cout << "Usage: " << NAME << " [OPTION...]\n" + << NAME << " is a C++ Boost-based HTTP server.\n\n" + "Example:\n" + << " " << NAME << " -p 8081 --threads 2\n\n" + "Option:\n" + " -h, --help show help and exit\n" + " -p, --port port (default: 8080)\n" + " --threads number of threads (default: 4)\n" + " --num-conn number of database connections (default: 10)\n" + " --payload payload limit for request in mega bytes (default: 1)\n" + " --rotation log rotation size in mega bytes (default: 4)\n" + " --log-path log path (default: ./log/)\n" + " -c, --conn-str connection string (default: dbname=bserv)" + << std::endl; +} + +// returns `true` if error occurs +bool parse_arguments(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + show_usage(); + return true; + } else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--port") == 0) { + if (i + 1 < argc) { + PORT = atoi(argv[i + 1]); + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "--threads") == 0) { + if (i + 1 < argc) { + NUM_THREADS = atoi(argv[i + 1]); + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "--num-conn") == 0) { + if (i + 1 < argc) { + NUM_DB_CONN = atoi(argv[i + 1]); + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "--payload") == 0) { + if (i + 1 < argc) { + PAYLOAD_LIMIT = atoi(argv[i + 1]) * 1024 * 1024; + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "--rotation") == 0) { + if (i + 1 < argc) { + LOG_ROTATION_SIZE = atoi(argv[i + 1]) * 1024 * 1024; + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "--log-path") == 0) { + if (i + 1 < argc) { + LOG_PATH = argv[i + 1]; + if (LOG_PATH.back() != '/') + LOG_PATH += '/'; + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--conn-str") == 0) { + if (i + 1 < argc) { + DB_CONN_STR = argv[i + 1]; + ++i; + } else { + std::cerr << "Missing value after: " << argv[i] << std::endl; + return true; + } + } else { + std::cerr << "Unrecognized option: " << argv[i] << '\n' << std::endl; + show_usage(); + return true; + } + } + return false; +} + +} // bserv + +#endif // _CONFIG_HPP \ No newline at end of file diff --git a/database.hpp b/database.hpp new file mode 100644 index 0000000..5c4235b --- /dev/null +++ b/database.hpp @@ -0,0 +1,303 @@ +#ifndef _DATABASE_HPP +#define _DATABASE_HPP + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace bserv { + +class db_connection_manager; + +class db_connection { +private: + db_connection_manager& mgr_; + std::shared_ptr conn_; +public: + db_connection( + db_connection_manager& mgr, + std::shared_ptr conn) + : mgr_{mgr}, conn_{conn} {} + // non-copiable, non-assignable + db_connection(const db_connection&) = delete; + db_connection& operator=(const db_connection&) = delete; + // during the destruction, it should put itself back to the + // manager's queue + ~db_connection(); + pqxx::connection& get() { return *conn_; } +}; + +// provides the database connection pool functionality +class db_connection_manager { +private: + std::queue> queue_; + // this lock is for manipulating the `queue_` + mutable std::mutex queue_lock_; + // since C++ 17 doesn't provide the semaphore functionality, + // mutex is used to mimic it. (boost provides it) + // if there are no available connections, trying to lock on + // it will cause blocking. + mutable std::mutex counter_lock_; + friend db_connection; +public: + db_connection_manager(const std::string& conn_str, int n) { + for (int i = 0; i < n; ++i) + queue_.emplace( + std::make_shared(conn_str)); + } + // if there are no available database connections, this function + // blocks until there is any; + // otherwise, this function returns a pointer to `db_connection`. + std::shared_ptr get_or_block() { + // `counter_lock_` must be acquired first. + // exchanging this statement with the next will cause dead-lock, + // because if the request is blocked by `counter_lock_`, + // the destructor of `db_connection` will not be able to put + // itself back due to the `queue_lock_` has already been acquired + // by this request! + counter_lock_.lock(); + // `queue_lock_` is acquired so that only one thread will + // modify the `queue_` + std::lock_guard lg{queue_lock_}; + std::shared_ptr conn = queue_.front(); + queue_.pop(); + // if there are no connections in the `queue_`, + // `counter_lock_` remains to be locked + // so that the following requests will be blocked + if (queue_.size() != 0) counter_lock_.unlock(); + return std::make_shared(*this, conn); + } +}; + +db_connection::~db_connection() { + std::lock_guard lg{mgr_.queue_lock_}; + mgr_.queue_.emplace(conn_); + // if this is the first available connection back to the queue, + // `counter_lock_` is unlocked so that the blocked requests will + // be notified + if (mgr_.queue_.size() == 1) + mgr_.counter_lock_.unlock(); +} + +std::shared_ptr db_conn_mgr; + +// ************************************************************************** + +class db_parameter { +public: + virtual ~db_parameter() = default; + virtual std::string get_value(pqxx::work&) = 0; +}; + +class db_name : public db_parameter { +private: + std::string value_; +public: + db_name(const std::string& value) + : value_{value} {} + std::string get_value(pqxx::work& w) { + return w.quote_name(value_); + } +}; + +template +class db_value : public db_parameter { +private: + Type value_; +public: + db_value(const Type& value) + : value_{value} {} + std::string get_value(pqxx::work&) { + return std::to_string(value_); + } +}; + +template <> +class db_value : public db_parameter { +private: + std::string value_; +public: + db_value(const std::string& value) + : value_{value} {} + std::string get_value(pqxx::work& w) { + return w.quote(value_); + } +}; + +template <> +class db_value : public db_parameter { +private: + bool value_; +public: + db_value(const bool& value) + : value_{value} {} + std::string get_value(pqxx::work&) { + return value_ ? "true" : "false"; + } +}; + +namespace db_internal { + +template +std::shared_ptr convert_parameter( + const Param& param) { + return std::make_shared>(param); +} + +template +std::shared_ptr convert_parameter( + const db_value& param) { + return std::make_shared>(param); +} + +std::shared_ptr convert_parameter( + const char* param) { + return std::make_shared>(param); +} + +std::shared_ptr convert_parameter( + const db_name& param) { + return std::make_shared(param); +} + +template +std::vector convert_parameters( + pqxx::work& w, std::shared_ptr... params) { + return {params->get_value(w)...}; +} + +// ************************************* + +class db_field_holder { +protected: + std::string name_; +public: + db_field_holder(const std::string& name) + : name_{name} {} + virtual ~db_field_holder() = default; + virtual void add( + const pqxx::row& row, size_t field_idx, + boost::json::object& obj) = 0; +}; + +template +class db_field : public db_field_holder { +public: + using db_field_holder::db_field_holder; + void add( + const pqxx::row& row, size_t field_idx, + boost::json::object& obj) { + obj[name_] = row[field_idx].as(); + } +}; + +template <> +class db_field : public db_field_holder { +public: + using db_field_holder::db_field_holder; + void add( + const pqxx::row& row, size_t field_idx, + boost::json::object& obj) { + obj[name_] = row[field_idx].c_str(); + } +}; + +} // db_internal + +template +std::shared_ptr make_db_field( + const std::string& name) { + return std::make_shared>(name); +} + +class invalid_operation_exception : public std::exception { +private: + std::string msg_; +public: + invalid_operation_exception(const std::string& msg) + : msg_{msg} {} + const char* what() const noexcept { return msg_.c_str(); } +}; + +class db_relation_to_object { +private: + std::vector> fields_; +public: + db_relation_to_object( + const std::initializer_list< + std::shared_ptr>& fields) + : fields_{fields} {} + boost::json::object convert_row(const pqxx::row& row) { + boost::json::object obj; + for (size_t i = 0; i < fields_.size(); ++i) + fields_[i]->add(row, i, obj); + return obj; + } + std::vector convert_to_vector( + const pqxx::result& result) { + std::vector results; + for (const auto& row : result) + results.emplace_back(convert_row(row)); + return results; + } + std::optional convert_to_optional( + const pqxx::result& result) { + if (result.size() == 0) return std::nullopt; + if (result.size() == 1) return convert_row(result[0]); + // result.size() > 1 + throw invalid_operation_exception{ + "too many objects to convert"}; + } +}; + +// Usage: +// db_exec(tx, "select * from ? where ? = ? and first_name = 'Name??'", +// db_name("auth_user"), db_name("is_active"), db_value(true)); +// -> SQL: select * from "auth_user" where "is_active" = true and first_name = 'Name?' +// ====================================================================================== +// db_exec(tx, "select * from ? where ? = ? and first_name = ?", +// db_name("auth_user"), db_name("is_active"), false, "Name??"); +// -> SQL: select * from "auth_user" where "is_active" = false and first_name = 'Name??' +// ====================================================================================== +// Note: "?" is the placeholder for parameters, and "??" will be converted to "?" in SQL. +// But, "??" in the parameters remains. +template +pqxx::result db_exec(pqxx::work& w, + const std::string& s, const Params&... params) { + std::vector param_vec = + db_internal::convert_parameters( + w, db_internal::convert_parameter(params)...); + size_t idx = 0; + std::string query; + for (size_t i = 0; i < s.length(); ++i) { + if (s[i] == '?') { + if (i + 1 < s.length() && s[i + 1] == '?') { + query += s[++i]; + } else { + if (idx < param_vec.size()) { + query += param_vec[idx++]; + } else throw std::out_of_range{"too few parameters"}; + } + } else query += s[i]; + } + if (idx != param_vec.size()) + throw invalid_operation_exception{"too many parameters"}; + return w.exec(query); +} + + +// TODO: add support for time conversions between postgresql and c++, use timestamp? +// what about time zone? + +} // bserv + +#endif // _DATABASE_HPP \ No newline at end of file diff --git a/db.sql b/db.sql new file mode 100644 index 0000000..7eca78a --- /dev/null +++ b/db.sql @@ -0,0 +1,10 @@ +CREATE TABLE auth_user ( + id serial PRIMARY KEY, + username character varying(255) NOT NULL UNIQUE, + password character varying(255) NOT NULL, + is_superuser boolean NOT NULL, + first_name character varying(255) NOT NULL, + last_name character varying(255) NOT NULL, + email character varying(255) NOT NULL, + is_active boolean NOT NULL +); diff --git a/handlers.hpp b/handlers.hpp new file mode 100644 index 0000000..5224cf1 --- /dev/null +++ b/handlers.hpp @@ -0,0 +1,212 @@ +#ifndef _HANDLERS_HPP +#define _HANDLERS_HPP + +#include + +#include +#include +#include +#include + +#include + +#include "common.hpp" + +// register an orm mapping (to convert the db query results into +// json objects). +// the db query results contain several rows, each has a number of +// fields. the order of `make_db_field(name[i])` in the +// initializer list corresponds to these fields (`Type[0]` and +// `name[0]` correspond to field[0], `Type[1]` and `name[1]` +// correspond to field[1], ...). `Type[i]` is the type you want +// to convert the field value to, and `name[i]` is the identifier +// with which you want to store the field in the json object, so +// if the returned json object is `obj`, `obj[name[i]]` will have +// the type of `Type[i]` and store the value of field[i]. +bserv::db_relation_to_object orm_user{ + bserv::make_db_field("id"), + bserv::make_db_field("username"), + bserv::make_db_field("password"), + bserv::make_db_field("is_superuser"), + bserv::make_db_field("first_name"), + bserv::make_db_field("last_name"), + bserv::make_db_field("email"), + bserv::make_db_field("is_active") +}; + +std::optional get_user( + pqxx::work& tx, + const std::string& username) { + pqxx::result r = bserv::db_exec(tx, + "select * from auth_user where username = ?", username); + lginfo << r.query(); // this is how you log info + return orm_user.convert_to_optional(r); +} + +std::string get_or_empty( + boost::json::object& obj, + const std::string& key) { + return obj.count(key) ? obj[key].as_string().c_str() : ""; +} + +// if you want to manually modify the response, +// the return type should be `std::nullopt_t`, +// and the return value should be `std::nullopt`. +std::nullopt_t hello( + bserv::response_type& response, + std::shared_ptr session_ptr) { + bserv::session_type& session = *session_ptr; + boost::json::object obj; + if (session.count("user")) { + auto& user = session["user"].as_object(); + obj = { + {"msg", std::string{"welcome, "} + + user["username"].as_string().c_str() + "!"} + }; + } else { + obj = {{"msg", "hello, world!"}}; + } + // the response body is a string, + // so the `obj` should be serialized + response.body() = boost::json::serialize(obj); + response.prepare_payload(); // this line is important! + return std::nullopt; +} + +// if you return a json object, the serialization +// is performed automatically. +boost::json::object user_register( + bserv::request_type& request, + // the json object is obtained from the request body, + // as well as the url parameters + boost::json::object&& params, + std::shared_ptr conn) { + if (request.method() != boost::beast::http::verb::post) { + throw bserv::url_not_found_exception{}; + } + if (params.count("username") == 0) { + return { + {"success", false}, + {"message", "`username` is required"} + }; + } + if (params.count("password") == 0) { + return { + {"success", false}, + {"message", "`password` is required"} + }; + } + auto username = params["username"].as_string(); + pqxx::work tx{conn->get()}; + auto opt_user = get_user(tx, username.c_str()); + if (opt_user.has_value()) { + return { + {"success", false}, + {"message", "`username` existed"} + }; + } + auto password = params["password"].as_string(); + pqxx::result r = bserv::db_exec(tx, + "insert into ? " + "(?, password, is_superuser, " + "first_name, last_name, email, is_active) values " + "(?, ?, ?, ?, ?, ?, ?)", bserv::db_name("auth_user"), + bserv::db_name("username"), + username.c_str(), + bserv::utils::security::encode_password( + password.c_str()), false, + get_or_empty(params, "first_name"), + get_or_empty(params, "last_name"), + get_or_empty(params, "email"), true); + lginfo << r.query(); + tx.commit(); // you must manually commit changes + return { + {"success", true}, + {"message", "user registered"} + }; +} + +boost::json::object user_login( + bserv::request_type& request, + boost::json::object&& params, + std::shared_ptr conn, + std::shared_ptr session_ptr) { + if (request.method() != boost::beast::http::verb::post) { + throw bserv::url_not_found_exception{}; + } + if (params.count("username") == 0) { + return { + {"success", false}, + {"message", "`username` is required"} + }; + } + if (params.count("password") == 0) { + return { + {"success", false}, + {"message", "`password` is required"} + }; + } + auto username = params["username"].as_string(); + pqxx::work tx{conn->get()}; + auto opt_user = get_user(tx, username.c_str()); + if (!opt_user.has_value()) { + return { + {"success", false}, + {"message", "invalid username/password"} + }; + } + auto& user = opt_user.value(); + if (!user["is_active"].as_bool()) { + return { + {"success", false}, + {"message", "invalid username/password"} + }; + } + auto password = params["password"].as_string(); + auto encoded_password = user["password"].as_string(); + if (!bserv::utils::security::check_password( + password.c_str(), encoded_password.c_str())) { + return { + {"success", false}, + {"message", "invalid username/password"} + }; + } + bserv::session_type& session = *session_ptr; + session["user"] = user; + return { + {"success", true}, + {"message", "login successfully"} + }; +} + +boost::json::object find_user( + std::shared_ptr conn, + const std::string& username) { + pqxx::work tx{conn->get()}; + auto user = get_user(tx, username); + if (!user.has_value()) { + return { + {"success", false}, + {"message", "requested user does not exist"} + }; + } + user.value().erase("password"); + return { + {"success", true}, + {"user", user.value()} + }; +} + +boost::json::object user_logout( + std::shared_ptr session_ptr) { + bserv::session_type& session = *session_ptr; + if (session.count("user")) { + session.erase("user"); + } + return { + {"success", true}, + {"message", "logout successfully"} + }; +} + +#endif // _HANDLERS_HPP \ No newline at end of file diff --git a/logging.hpp b/logging.hpp new file mode 100644 index 0000000..24ecec1 --- /dev/null +++ b/logging.hpp @@ -0,0 +1,42 @@ +#ifndef _LOGGING_HPP +#define _LOGGING_HPP + +#define BOOST_LOG_DYN_LINK + +#include +#include +#include +#include + +#include "config.hpp" + +namespace bserv { + +namespace logging = boost::log; +namespace keywords = boost::log::keywords; +namespace src = boost::log::sources; + +// this function should be called in `main` +// right after the configurations are loaded. +void init_logging() { + logging::add_file_log( + keywords::file_name = LOG_PATH + NAME + "_%Y%m%d_%H-%M-%S.%N.log", + keywords::rotation_size = LOG_ROTATION_SIZE, + keywords::format = "[%Severity%][%TimeStamp%][%ThreadID%]: %Message%" + ); + logging::core::get()->set_filter( + logging::trivial::severity >= logging::trivial::trace + ); + logging::add_common_attributes(); +} + +#define lgtrace BOOST_LOG_TRIVIAL(trace) +#define lgdebug BOOST_LOG_TRIVIAL(debug) +#define lginfo BOOST_LOG_TRIVIAL(info) +#define lgwarning BOOST_LOG_TRIVIAL(warning) +#define lgerror BOOST_LOG_TRIVIAL(error) +#define lgfatal BOOST_LOG_TRIVIAL(fatal) + +} // bserv + +#endif // _LOGGING_HPP \ No newline at end of file diff --git a/router.hpp b/router.hpp new file mode 100644 index 0000000..b4d7156 --- /dev/null +++ b/router.hpp @@ -0,0 +1,354 @@ +#ifndef _ROUTER_HPP +#define _ROUTER_HPP + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "database.hpp" +#include "session.hpp" +#include "utils.hpp" +#include "config.hpp" + +namespace bserv { + +namespace beast = boost::beast; +namespace http = beast::http; + +using request_type = http::request; +using response_type = http::response; + +namespace placeholders { + +template +struct placeholder {}; + +#define make_place_holder(x) constexpr placeholder _##x + +make_place_holder(1); +make_place_holder(2); +make_place_holder(3); +make_place_holder(4); +make_place_holder(5); +make_place_holder(6); +make_place_holder(7); +make_place_holder(8); +make_place_holder(9); + +#undef make_place_holder + +// std::shared_ptr +constexpr placeholder<-1> session; +// bserv::request_type& +constexpr placeholder<-2> request; +// bserv::response_type& +constexpr placeholder<-3> response; +// boost::json::object&& +constexpr placeholder<-4> json_params; +// std::shared_ptr +constexpr placeholder<-5> transaction; + +} // placeholders + +class bad_request_exception : public std::exception { +public: + bad_request_exception() = default; + const char* what() const noexcept { return "bad request"; } +}; + +namespace router_internal { + +template +struct parameter_pack; + +template <> +struct parameter_pack<> {}; + +template +struct parameter_pack + : parameter_pack { + Head head_; + template + parameter_pack(Head2&& head, Tail2&& ...tail) + : parameter_pack{static_cast(tail)...}, + head_{static_cast(head)} {} +}; + +template +struct get_parameter_pack; + +template +struct get_parameter_pack + : get_parameter_pack {}; + +template +struct get_parameter_pack<0, Head, Tail...> { + using type = parameter_pack; +}; + +template +decltype(auto) get_parameter_value(parameter_pack& params) { + return (static_cast< + typename get_parameter_pack::type& + >(params).head_); +} + +template +struct get_parameter; + +template +struct get_parameter + : get_parameter {}; + +template +struct get_parameter<0, Head, Tail...> { + using type = Head; +}; + +template +Type&& get_parameter_data( + const std::vector&, + request_type&, response_type&, Type&& val) { + return static_cast(val); +} + +template = 0), int> = 0> +const std::string& get_parameter_data( + const std::vector& url_params, + request_type&, response_type&, + placeholders::placeholder) { + return url_params[N]; +} + +std::shared_ptr get_parameter_data( + const std::vector&, + request_type& request, response_type& response, + placeholders::placeholder<-1>) { + std::string cookie_str{request[http::field::cookie]}; + auto&& [cookie_dict, cookie_list] + = utils::parse_params(cookie_str, 0, ';'); + boost::ignore_unused(cookie_list); + std::string session_id; + if (cookie_dict.count(SESSION_NAME) != 0) { + session_id = cookie_dict[SESSION_NAME]; + } + std::shared_ptr session_ptr; + if (session_mgr->get_or_create(session_id, session_ptr)) { + response.set(http::field::set_cookie, SESSION_NAME + "=" + session_id); + } + return session_ptr; +} + +request_type& get_parameter_data( + const std::vector&, + request_type& request, response_type&, + placeholders::placeholder<-2>) { + return request; +} + +response_type& get_parameter_data( + const std::vector&, + request_type&, response_type& response, + placeholders::placeholder<-3>) { + return response; +} + +boost::json::object get_parameter_data( + const std::vector&, + request_type& request, response_type&, + placeholders::placeholder<-4>) { + std::string target{request.target()}; + auto&& [url, dict_params, list_params] = utils::parse_url(target); + boost::ignore_unused(url); + boost::json::object body; + if (!request.body().empty()) { + try { + body = boost::json::parse(request.body()).as_object(); + } catch (const std::exception& e) { + throw bad_request_exception{}; + } + } + for (auto& [k, v] : dict_params) { + if (!body.contains(k)) { + body[k] = v; + } + } + for (auto& [k, vs] : list_params) { + if (!body.contains(k)) { + boost::json::array a; + for (auto& v : vs) { + a.push_back(boost::json::string{v}); + } + body[k] = a; + } + } + return body; +} + +std::shared_ptr get_parameter_data( + const std::vector&, + request_type&, response_type&, + placeholders::placeholder<-5>) { + return db_conn_mgr->get_or_block(); +} + +template +struct path_handler; + +template +struct path_handler> { + Ret invoke(Ret (*pf)(Args ...), parameter_pack& params, + const std::vector& url_params, + request_type& request, response_type& response) { + if constexpr (Idx == 0) return (*pf)(); + else return static_cast, + typename get_parameter::type>* + >(this)->invoke2(pf, params, url_params, request, response, + get_parameter_data(url_params, request, response, + get_parameter_value(params))); + } +}; + +template +struct path_handler, Head, Tail...> + : path_handler, Tail...> { + template < + typename Head2, typename ...Tail2, + std::enable_if_t = 0> + Ret invoke2(Ret (*pf)(Args ...), parameter_pack& params, + const std::vector& url_params, + request_type& request, response_type& response, + Head2&& head2, Tail2&& ...tail2) { + if constexpr (Idx == 0) + return (*pf)(static_cast(head2), + static_cast(tail2)...); + else return static_cast, + typename get_parameter::type, Head, Tail...>* + >(this)->invoke2(pf, params, url_params, request, response, + get_parameter_data(url_params, request, response, + get_parameter_value(params)), + static_cast(head2), static_cast(tail2)...); + } +}; + +const std::vector> url_regex_mapping{ + {std::regex{""}, "([0-9]+)"}, + {std::regex{""}, R"(([A-Za-z0-9_\.\-]+))"}, + {std::regex{""}, R"(([A-Za-z0-9_/\.\-]+))"} +}; + +std::string get_re_url(const std::string& url) { + std::string re_url = url; + for (auto& [r, s] : url_regex_mapping) + re_url = std::regex_replace(re_url, r, s); + return re_url; +} + +struct path_holder : std::enable_shared_from_this { + path_holder() = default; + virtual ~path_holder() = default; + virtual bool match( + const std::string&, + std::vector&) const = 0; + virtual std::optional invoke( + const std::vector&, + request_type&, response_type&) = 0; +}; + +template +class path; + +template +class path> + : public path_holder { +private: + std::regex re_; + Ret (*pf_)(Args ...); + parameter_pack params_; + path_handler<0, Ret (*)(Args ...), parameter_pack, Params...> handler_; +public: + path(const std::string& url, Ret (*pf)(Args ...), Params&& ...params) + : re_{get_re_url(url)}, pf_{pf}, + params_{static_cast(params)...} {} + bool match(const std::string& url, std::vector& result) const { + std::smatch r; + bool matched = std::regex_match(url, r, re_); + if (matched) { + result.clear(); + for (auto & sub : r) + result.push_back(sub.str()); + } + return matched; + } + std::optional invoke( + const std::vector& url_params, + request_type& request, response_type& response) { + return handler_.invoke( + pf_, params_, url_params, + request, response); + } +}; + +} // router_internal + +template +std::shared_ptr>> make_path( + const std::string& url, Ret (*pf)(Args ...), Params&& ...params) { + return std::make_shared< + router_internal::path> + >(url, pf, static_cast(params)...); +} + +template +std::shared_ptr>> make_path( + const char* url, Ret (*pf)(Args ...), Params&& ...params) { + return std::make_shared< + router_internal::path> + >(url, pf, static_cast(params)...); +} + +class url_not_found_exception : public std::exception { +public: + url_not_found_exception() = default; + const char* what() const noexcept { return "url not found"; } +}; + +class router { +private: + using path_holder_type = std::shared_ptr; + std::vector paths_; +public: + router(const std::initializer_list& paths) + : paths_{paths} {} + std::optional operator()( + const std::string& url, request_type& request, response_type& response) { + std::vector url_params; + for (auto& ptr : paths_) { + if (ptr->match(url, url_params)) + return ptr->invoke(url_params, request, response); + } + throw url_not_found_exception{}; + } +}; + +} // bserv + +#endif // _ROUTER_HPP \ No newline at end of file diff --git a/routing.hpp b/routing.hpp new file mode 100644 index 0000000..1814ba9 --- /dev/null +++ b/routing.hpp @@ -0,0 +1,32 @@ +#ifndef _ROUTING_HPP +#define _ROUTING_HPP + +#include "router.hpp" + +#include "handlers.hpp" + +namespace bserv { + +bserv::router routes{ + bserv::make_path("/", &hello, + bserv::placeholders::response, + bserv::placeholders::session), + bserv::make_path("/register", &user_register, + bserv::placeholders::request, + bserv::placeholders::json_params, + bserv::placeholders::transaction), + bserv::make_path("/login", &user_login, + bserv::placeholders::request, + bserv::placeholders::json_params, + bserv::placeholders::transaction, + bserv::placeholders::session), + bserv::make_path("/logout", &user_logout, + bserv::placeholders::session), + bserv::make_path("/find/", &find_user, + bserv::placeholders::transaction, + bserv::placeholders::_1) +}; + +} // bserv + +#endif // _ROUTING_HPP \ No newline at end of file diff --git a/session.hpp b/session.hpp new file mode 100644 index 0000000..a150cad --- /dev/null +++ b/session.hpp @@ -0,0 +1,110 @@ +#ifndef _SESSION_HPP +#define _SESSION_HPP + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +namespace bserv { + +const std::string SESSION_NAME = "bsessionid"; + +// using session_type = std::map; +using session_type = boost::json::object; + +struct session_base + : std::enable_shared_from_this { + virtual ~session_base() = default; + // if `key` refers to an existing session, that session will be placed in + // `session_ptr` and this function will return `false`. + // otherwise, this function will create a new session, place the created + // session in `session_ptr`, place the session id in `key`, and return `true`. + // this means, the returned value indicates whether a new session is created, + // the `session_ptr` will point to a session with `key` as its session id, + // after this function is called. + // NOTE: a `shared_ptr` is returned instead of a reference. + virtual bool get_or_create( + std::string& key, + std::shared_ptr& session_ptr) = 0; +}; + +std::shared_ptr session_mgr; + +class memory_session : public session_base { +private: + using time_point = std::chrono::steady_clock::time_point; + std::mt19937 rng_; + std::uniform_int_distribution dist_; + std::map str_to_int_; + std::map int_to_str_; + std::map> sessions_; + // `expiry` stores tuple sorted by key + std::map expiry_; + // `queue` functions as a priority queue + // (the front element is the smallest) + // and stores tuple sorted by + // expiry first and then key. + std::set> queue_; + mutable std::mutex lock_; +public: + memory_session() + : rng_{utils::internal::get_rd_value()}, + dist_{0, std::numeric_limits::max()} {} + bool get_or_create( + std::string& key, + std::shared_ptr& session_ptr) { + std::lock_guard lg{lock_}; + time_point now = std::chrono::steady_clock::now(); + // removes the expired sessions + while (!queue_.empty() && queue_.begin()->first < now) { + std::size_t another_key = queue_.begin()->second; + sessions_.erase(another_key); + expiry_.erase(another_key); + str_to_int_.erase(int_to_str_[another_key]); + int_to_str_.erase(another_key); + queue_.erase(queue_.begin()); + } + bool created = false; + std::size_t int_key; + if (key.empty() || str_to_int_.count(key) == 0) { + do { + key = utils::generate_random_string(32); + } while (str_to_int_.count(key) != 0); + do { + int_key = dist_(rng_); + } while (int_to_str_.count(int_key) != 0); + str_to_int_[key] = int_key; + int_to_str_[int_key] = key; + sessions_[int_key] = std::make_shared(); + created = true; + } else { + int_key = str_to_int_[key]; + queue_.erase( + queue_.lower_bound( + std::make_pair(expiry_[int_key], int_key))); + } + // the expiry is set to be 20 minutes from now. + // if the session is re-visited within 20 minutes, + // the expiry will be extended. + expiry_[int_key] = now + std::chrono::minutes(20); + // pushes expiry-key tuple (pair) to the queue + queue_.emplace(expiry_[int_key], int_key); + session_ptr = sessions_[int_key]; + return created; + } +}; + +} // bserv + +#endif // _SESSION_HPP \ No newline at end of file diff --git a/utils.hpp b/utils.hpp new file mode 100644 index 0000000..f135fe3 --- /dev/null +++ b/utils.hpp @@ -0,0 +1,224 @@ +#ifndef _UTILS_HPP +#define _UTILS_HPP + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace bserv::utils { + +namespace internal { + +// NOTE: +// - `random_device` is implementation dependent. +// it doesn't work with GNU GCC on Windows. +// - for thread-safety, do not directly use it. +// use `get_rd_value` instead. +std::random_device rd; +std::mutex rd_mutex; + +auto get_rd_value() { + std::lock_guard lg{rd_mutex}; + return rd(); +} + +// const std::string chars = "abcdefghijklmnopqrstuvwxyz" +// "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +// "1234567890" +// "!@#$%^&*()" +// "`~-_=+[{]}\\|;:'\",<.>/? "; + +const std::string chars = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "1234567890"; + +} // internal + +// https://www.boost.org/doc/libs/1_75_0/libs/random/example/password.cpp +std::string generate_random_string(std::size_t len) { + std::string s; + std::mt19937 rng{internal::get_rd_value()}; + std::uniform_int_distribution<> dist{0, (int) internal::chars.length() - 1}; + for (std::size_t i = 0; i < len; ++i) s += internal::chars[dist(rng)]; + return s; +} + +namespace security { + +// https://codahale.com/a-lesson-in-timing-attacks/ +bool constant_time_compare(const std::string& a, const std::string& b) { + if (a.length() != b.length()) + return false; + int result = 0; + for (std::size_t i = 0; i < a.length(); ++i) + result |= a[i] ^ b[i]; + return result == 0; +} + +// https://cryptopp.com/wiki/PKCS5_PBKDF2_HMAC +std::string hash_password( + const std::string& password, + const std::string& salt, + unsigned int iterations = 20000 /*320000*/) { + using namespace CryptoPP; + byte derived[SHA256::DIGESTSIZE]; + PKCS5_PBKDF2_HMAC pbkdf; + byte unused = 0; + pbkdf.DeriveKey(derived, sizeof(derived), unused, + (const byte*) password.c_str(), password.length(), + (const byte*) salt.c_str(), salt.length(), + iterations, 0.0f); + std::string result; + Base64Encoder encoder{new StringSink{result}, false}; + encoder.Put(derived, sizeof(derived)); + encoder.MessageEnd(); + return result; +} + +std::string encode_password(const std::string& password) { + std::string salt = generate_random_string(16); + std::string hashed_password = hash_password(password, salt); + return salt + '$' + hashed_password; +} + +bool check_password(const std::string& password, + const std::string& encoded_password) { + std::string salt, hashed_password; + std::string* a = &salt, * b = &hashed_password; + for (std::size_t i = 0; i < encoded_password.length(); ++i) { + if (encoded_password[i] != '$') { + (*a) += encoded_password[i]; + } else { + std::swap(a, b); + } + } + return constant_time_compare( + hash_password(password, salt), hashed_password); +} + +} // security + +// reference for url: +// https://www.ietf.org/rfc/rfc3986.txt + +// reserved = gen-delims / sub-delims +// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" +// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" +// / "*" / "+" / "," / ";" / "=" + +// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + +// https://stackoverflow.com/questions/54060359/encoding-decoded-urls-in-c +// there can be exceptions (std::stoi)! +std::string decode_url(const std::string& s) { + std::string r; + for (std::size_t i = 0; i < s.length(); ++i) { + if (s[i] == '%') { + int v = std::stoi(s.substr(i + 1, 2), nullptr, 16); + r.push_back(0xff & v); + i += 2; + } else if (s[i] == '+') r.push_back(' '); + else r.push_back(s[i]); + } + return r; +} + +// this function parses param list in the form of k1=v1&k2=v2..., +// where '&' can be any delimiter. +// ki and vi will be converted if they are percent-encoded, +// which is why the returned values are `string`, not `string_view`. +std::pair< + std::map, + std::map>> +parse_params(std::string& s, std::size_t start_pos = 0, char delimiter = '&') { + std::map dict_params; + std::map> list_params; + // we use the swap pointer technique + // we will always append characters to *a only. + std::string key, value, *a = &key, *b = &value; + // append an extra `delimiter` so that the last key-value pair + // is processed just like the other. + s.push_back(delimiter); + for (std::size_t i = start_pos; i < s.length(); ++i) { + if (s[i] == '=') { + std::swap(a, b); + } else if (s[i] == delimiter) { + // swap(a, b); + a = &key; + b = &value; + // prevent ending with ' ' + while (!key.empty() && key.back() == ' ') key.pop_back(); + while (!value.empty() && value.back() == ' ') value.pop_back(); + if (key.empty() && value.empty()) + continue; + key = decode_url(key); + value = decode_url(value); + // if `key` is in `list_params`, append `value`. + auto p = list_params.find(key); + if (p != list_params.end()) { + list_params[key].push_back(value); + } else { // `key` is not in `list_params` + auto p = dict_params.find(key); + // if `key` is in `dict_params`, + // move previous value and `value` to `list_params` + // and remove `key` in `dict_params`. + if (p != dict_params.end()) { + list_params[key] = {p->second, value}; + dict_params.erase(p); + } else { // `key` is not in `dict_params` + dict_params[key] = value; + } + } + // clear `key` and `value` + key = ""; + value = ""; + } else { + // prevent beginning with ' ' + if (a->empty() && s[i] == ' ') { + continue; + } + (*a) += s[i]; + } + } + // remove the last `delimiter` to restore `s` to what it was. + s.pop_back(); + return std::make_pair(dict_params, list_params); +} + +// this function parses url in the form of [url]?k1=v1&k2=v2... +// this function will convert ki and vi if they are percent-encoded. +// NOTE: don't misuse this function, it's going to modify +// the parameter `s` in place! +std::tuple, + std::map>> +parse_url(std::string& s) { + std::string url; + std::size_t i = 0; + for (; i < s.length(); ++i) { + if (s[i] != '?') { + url += s[i]; + } else { + break; + } + } + if (i == s.length()) + return std::make_tuple(url, + std::map{}, + std::map>{}); + auto&& [dict_params, list_params] = parse_params(s, i + 1); + return std::make_tuple(url, dict_params, list_params); +} + +} // bserv::utils + +#endif // _UTILS_HPP \ No newline at end of file