From 22c07227d20e17be66af923aaebb511ac8b91d46 Mon Sep 17 00:00:00 2001 From: jie Date: Sat, 7 Aug 2021 20:56:19 +0800 Subject: [PATCH] add websocket support & replace promise/future with coroutine --- CMakeLists.txt | 15 +- bserv/CMakeLists.txt | 26 +++ bserv/client.hpp | 234 +++++++------------ bserv/common.hpp | 1 + bserv/database.hpp | 8 +- bserv/logging.hpp | 4 +- bserv/router.hpp | 74 ++++-- bserv/server.cpp | 503 ++++++++++++++++++++++++++++++++++++++++ bserv/server.hpp | 348 +-------------------------- bserv/session.hpp | 2 +- bserv/utils.hpp | 22 +- bserv/websocket.hpp | 65 ++++++ handlers.hpp | 39 +++- main.cpp | 8 +- scripts/request_test.py | 48 ++++ scripts/ws_test.py | 62 +++++ 16 files changed, 917 insertions(+), 542 deletions(-) create mode 100644 bserv/CMakeLists.txt create mode 100644 bserv/server.cpp create mode 100644 bserv/websocket.hpp create mode 100644 scripts/request_test.py create mode 100644 scripts/ws_test.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 53b4201..c849fb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.10) -project(bserv) +project(bserv_main) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED True) @@ -13,12 +13,7 @@ set(CMAKE_CXX_FLAGS "-Wall -Wextra") set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_RELEASE "-O3") -add_executable(bserv main.cpp) -target_link_libraries(bserv - pthread - boost_thread - boost_log - boost_log_setup - pqxx - pq - cryptopp) +add_subdirectory(bserv) + +add_executable(main main.cpp) +target_link_libraries(main bserv) diff --git a/bserv/CMakeLists.txt b/bserv/CMakeLists.txt new file mode 100644 index 0000000..df97293 --- /dev/null +++ b/bserv/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required(VERSION 3.10) + +project(bserv) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_CXX_FLAGS "-Wall -Wextra") +set(CMAKE_CXX_FLAGS_DEBUG "-g") +set(CMAKE_CXX_FLAGS_RELEASE "-O3") + +add_library(bserv server.cpp) +target_link_libraries(bserv + pthread + boost_thread + boost_coroutine + boost_log + boost_log_setup + boost_json + pqxx + pq + cryptopp) diff --git a/bserv/client.hpp b/bserv/client.hpp index 31b2bd1..f9f88c9 100644 --- a/bserv/client.hpp +++ b/bserv/client.hpp @@ -2,8 +2,9 @@ #define _CLIENT_HPP #include +#include #include -#include +#include #include #include @@ -29,139 +30,74 @@ using response_type = http::response; class request_failed_exception : public std::exception { private: - std::string msg_; + const std::string msg_; public: request_failed_exception(const std::string& msg) : msg_{msg} {} const char* what() const noexcept { return msg_.c_str(); } }; // https://www.boost.org/doc/libs/1_75_0/libs/beast/example/http/client/async/http_client_async.cpp +// https://www.boost.org/doc/libs/1_75_0/libs/beast/example/http/client/coro/http_client_coro.cpp // sends one async request to a remote server -template -class http_client_session - : public std::enable_shared_from_this< - http_client_session> { -private: - tcp::resolver resolver_; - beast::tcp_stream stream_; - // must persist between reads - beast::flat_buffer buffer_; - http::request req_; - http::response res_; - std::promise promise_; - void failed(const beast::error_code& ec, const std::string& what) { - promise_.set_exception( - std::make_exception_ptr( - request_failed_exception{what + ": " + ec.message()})); - } -public: - http_client_session( +inline http::response http_client_send( asio::io_context& ioc, - const http::request& req) - : resolver_{asio::make_strand(ioc)}, - stream_{asio::make_strand(ioc)}, req_{req} {} - std::future send( + asio::yield_context& yield, const std::string& host, - const std::string& port) { - resolver_.async_resolve( - host, port, - beast::bind_front_handler( - &http_client_session::on_resolve, - http_client_session::shared_from_this())); - return promise_.get_future(); + const std::string& port, + const http::request& req) { + beast::error_code ec; + tcp::resolver resolver{ioc}; + const auto results = resolver.async_resolve(host, port, yield[ec]); + if (ec) { + throw request_failed_exception{"http_client_session::resolver resolve: " + ec.message()}; } - void on_resolve( - beast::error_code ec, - tcp::resolver::results_type results) { - if (ec) { - failed(ec, "http_client_session::resolver resolve"); - return; - } - // sets a timeout on the operation - stream_.expires_after(std::chrono::seconds(EXPIRY_TIME)); - // makes the connection on the IP address we get from a lookup - stream_.async_connect( - results, - beast::bind_front_handler( - &http_client_session::on_connect, - http_client_session::shared_from_this())); + beast::tcp_stream stream{ioc}; + // sets a timeout on the operation + stream.expires_after(std::chrono::seconds(EXPIRY_TIME)); + // makes the connection on the IP address we get from a lookup + stream.async_connect(results, yield[ec]); + if (ec) { + throw request_failed_exception{"http_client_session::stream connect: " + ec.message()}; } - void on_connect( - beast::error_code ec, - tcp::resolver::results_type::endpoint_type) { - if (ec) { - failed(ec, "http_client_session::stream connect"); - return; - } - // sets a timeout on the operation - stream_.expires_after(std::chrono::seconds(EXPIRY_TIME)); - // sends the HTTP request to the remote host - http::async_write( - stream_, req_, - beast::bind_front_handler( - &http_client_session::on_write, - http_client_session::shared_from_this())); + // sets a timeout on the operation + stream.expires_after(std::chrono::seconds(EXPIRY_TIME)); + // sends the HTTP request to the remote host + http::async_write(stream, req, yield[ec]); + if (ec) { + throw request_failed_exception{"http_client_session::stream write: " + ec.message()}; } - void on_write( - beast::error_code ec, - std::size_t bytes_transferred) { - boost::ignore_unused(bytes_transferred); - if (ec) { - failed(ec, "http_client_session::stream write"); - return; - } - // receives the HTTP response - http::async_read( - stream_, buffer_, res_, - beast::bind_front_handler( - &http_client_session::on_read, - http_client_session::shared_from_this())); + beast::flat_buffer buffer; + http::response res; + // receives the HTTP response + http::async_read(stream, buffer, res, yield[ec]); + if (ec) { + throw request_failed_exception{"http_client_session::stream read: " + ec.message()}; } - static_assert(std::is_same>::value - || std::is_same::value, - "unsupported `ResponseType`"); - void on_read( - beast::error_code ec, - std::size_t bytes_transferred) { - boost::ignore_unused(bytes_transferred); - if (ec) { - failed(ec, "http_client_session::stream read"); - return; - } - if constexpr (std::is_same>::value) { - promise_.set_value(std::move(res_)); - } else if constexpr (std::is_same::value) { - promise_.set_value(boost::json::parse(res_.body())); - } else { // this should never happen - promise_.set_exception( - std::make_exception_ptr( - request_failed_exception{"unsupported `ResponseType`"})); - } - // gracefully close the socket - stream_.socket().shutdown(tcp::socket::shutdown_both, ec); - // `not_connected` happens sometimes so don't bother reporting it - if (ec && ec != beast::errc::not_connected) { - // reports the error to the log! - fail(ec, "http_client_session::stream::socket shutdown"); - return; - } - // if we get here then the connection is closed gracefully + // gracefully close the socket + stream.socket().shutdown(tcp::socket::shutdown_both, ec); + // `not_connected` happens sometimes so don't bother reporting it + if (ec && ec != beast::errc::not_connected) { + // reports the error to the log! + fail(ec, "http_client_session::stream::socket shutdown"); + // return; } -}; + // if we get here then the connection is closed gracefully + return res; +} -request_type get_request( +inline request_type get_request( const std::string& host, const std::string& target, const http::verb& method, - const boost::json::object& obj) { + const boost::json::value& val) { request_type req; req.method(method); req.target(target); req.set(http::field::host, host); req.set(http::field::user_agent, NAME); req.set(http::field::content_type, "application/json"); - req.body() = boost::json::serialize(obj); + req.body() = boost::json::serialize(val); req.prepare_payload(); return req; } @@ -169,99 +105,97 @@ request_type get_request( class http_client { private: asio::io_context& ioc_; + asio::yield_context& yield_; public: - http_client(asio::io_context& ioc) - : ioc_{ioc} {} - std::future> request( + http_client(asio::io_context& ioc, asio::yield_context& yield) + : ioc_{ioc}, yield_{yield} {} + http::response request( const std::string& host, const std::string& port, const http::request& req) { - return std::make_shared< - http_client_session> - >(ioc_, req)->send(host, port); + return http_client_send(ioc_, yield_, host, port, req); } - std::future request_for_object( + boost::json::value request_for_value( const std::string& host, const std::string& port, const http::request& req) { - return std::make_shared< - http_client_session - >(ioc_, req)->send(host, port); + return boost::json::parse(request(host, port, req).body()); } - std::future send( + response_type send( const std::string& host, const std::string& port, const std::string& target, const http::verb& method, - const boost::json::object& obj) { - request_type req = get_request(host, target, method, obj); + const boost::json::value& val) { + request_type req = get_request(host, target, method, val); return request(host, port, req); } - std::future send_for_object( + boost::json::value send_for_value( const std::string& host, const std::string& port, const std::string& target, const http::verb& method, - const boost::json::object& obj) { - request_type req = get_request(host, target, method, obj); - return request_for_object(host, port, req); + const boost::json::value& val) { + request_type req = get_request(host, target, method, val); + return request_for_value(host, port, req); } - std::future get( + + response_type get( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send(host, port, target, http::verb::get, obj); + const boost::json::value& val) { + return send(host, port, target, http::verb::get, val); } - std::future get_for_object( + boost::json::value get_for_value( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send_for_object(host, port, target, http::verb::get, obj); + const boost::json::value& val) { + return send_for_value(host, port, target, http::verb::get, val); } - std::future put( + response_type put( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send(host, port, target, http::verb::put, obj); + const boost::json::value& val) { + return send(host, port, target, http::verb::put, val); } - std::future put_for_object( + boost::json::value put_for_value( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send_for_object(host, port, target, http::verb::put, obj); + const boost::json::value& val) { + return send_for_value(host, port, target, http::verb::put, val); } - std::future post( + response_type post( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send(host, port, target, http::verb::post, obj); + const boost::json::value& val) { + return send(host, port, target, http::verb::post, val); } - std::future post_for_object( + boost::json::value post_for_value( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send_for_object(host, port, target, http::verb::post, obj); + const boost::json::value& val) { + return send_for_value(host, port, target, http::verb::post, val); } - std::future delete_( + response_type delete_( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send(host, port, target, http::verb::delete_, obj); + const boost::json::value& val) { + return send(host, port, target, http::verb::delete_, val); } - std::future delete_for_object( + boost::json::value delete_for_value( const std::string& host, const std::string& port, const std::string& target, - const boost::json::object& obj) { - return send_for_object(host, port, target, http::verb::delete_, obj); + const boost::json::value& val) { + return send_for_value(host, port, target, http::verb::delete_, val); } }; diff --git a/bserv/common.hpp b/bserv/common.hpp index e825308..27bc7e6 100644 --- a/bserv/common.hpp +++ b/bserv/common.hpp @@ -9,5 +9,6 @@ #include "server.hpp" #include "session.hpp" #include "utils.hpp" +#include "websocket.hpp" #endif // _COMMON_HPP \ No newline at end of file diff --git a/bserv/database.hpp b/bserv/database.hpp index 8ecc678..38be01d 100644 --- a/bserv/database.hpp +++ b/bserv/database.hpp @@ -1,7 +1,7 @@ #ifndef _DATABASE_HPP #define _DATABASE_HPP -#include +#include #include #include @@ -78,7 +78,7 @@ public: } }; -db_connection::~db_connection() { +inline db_connection::~db_connection() { std::lock_guard lg{mgr_.queue_lock_}; mgr_.queue_.emplace(conn_); // if this is the first available connection back to the queue, @@ -157,12 +157,12 @@ std::shared_ptr convert_parameter( return std::make_shared>(param); } -std::shared_ptr convert_parameter( +inline std::shared_ptr convert_parameter( const char* param) { return std::make_shared>(param); } -std::shared_ptr convert_parameter( +inline std::shared_ptr convert_parameter( const db_name& param) { return std::make_shared(param); } diff --git a/bserv/logging.hpp b/bserv/logging.hpp index b02bb71..59c6c7d 100644 --- a/bserv/logging.hpp +++ b/bserv/logging.hpp @@ -20,7 +20,7 @@ namespace keywords = boost::log::keywords; namespace src = boost::log::sources; // this function should be called before logging is used -void init_logging(const server_config& config) { +inline void init_logging(const server_config& config) { logging::add_file_log( keywords::file_name = config.get_log_path() + "_%Y%m%d_%H-%M-%S.%N.log", keywords::rotation_size = config.get_log_rotation_size(), @@ -39,7 +39,7 @@ void init_logging(const server_config& config) { #define lgerror BOOST_LOG_TRIVIAL(error) #define lgfatal BOOST_LOG_TRIVIAL(fatal) -void fail(const boost::system::error_code& ec, const char* what) { +inline void fail(const boost::system::error_code& ec, const char* what) { lgerror << what << ": " << ec.message() << std::endl; } diff --git a/bserv/router.hpp b/bserv/router.hpp index 9e5d832..58c1e59 100644 --- a/bserv/router.hpp +++ b/bserv/router.hpp @@ -1,8 +1,10 @@ #ifndef _ROUTER_HPP #define _ROUTER_HPP +#include +#include #include -#include +#include #include #include @@ -19,6 +21,7 @@ #include "session.hpp" #include "utils.hpp" #include "config.hpp" +#include "websocket.hpp" namespace bserv { @@ -28,7 +31,6 @@ namespace http = beast::http; struct server_resources { std::shared_ptr session_mgr; std::shared_ptr db_conn_mgr; - std::shared_ptr http_client_ptr; }; namespace placeholders { @@ -62,6 +64,8 @@ constexpr placeholder<-4> json_params; constexpr placeholder<-5> db_connection_ptr; // std::shared_ptr constexpr placeholder<-6> http_client_ptr; +// std::shared_ptr +constexpr placeholder<-7> websocket_server_ptr; } // placeholders @@ -123,6 +127,8 @@ struct get_parameter<0, Head, Tail...> { template Type&& get_parameter_data( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type&, response_type&, Type&& val) { return static_cast(val); @@ -131,14 +137,18 @@ Type&& get_parameter_data( template = 0), int> = 0> const std::string& get_parameter_data( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector& url_params, request_type&, response_type&, placeholders::placeholder) { return url_params[N]; } -std::shared_ptr get_parameter_data( +inline std::shared_ptr get_parameter_data( server_resources& resources, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type& request, response_type& response, placeholders::placeholder<-1>) { @@ -157,24 +167,30 @@ std::shared_ptr get_parameter_data( return session_ptr; } -request_type& get_parameter_data( +inline request_type& get_parameter_data( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type& request, response_type&, placeholders::placeholder<-2>) { return request; } -response_type& get_parameter_data( +inline response_type& get_parameter_data( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type&, response_type& response, placeholders::placeholder<-3>) { return response; } -boost::json::object get_parameter_data( +inline boost::json::object get_parameter_data( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type& request, response_type&, placeholders::placeholder<-4>) { @@ -206,20 +222,34 @@ boost::json::object get_parameter_data( return body; } -std::shared_ptr get_parameter_data( +inline std::shared_ptr get_parameter_data( server_resources& resources, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type&, response_type&, placeholders::placeholder<-5>) { return resources.db_conn_mgr->get_or_block(); } -std::shared_ptr get_parameter_data( - server_resources& resources, +inline std::shared_ptr get_parameter_data( + server_resources&, + asio::io_context& ioc, asio::yield_context& yield, + std::shared_ptr, const std::vector&, request_type&, response_type&, placeholders::placeholder<-6>) { - return resources.http_client_ptr; + return std::make_shared(ioc, yield); +} + +inline std::shared_ptr get_parameter_data( + server_resources&, + asio::io_context&, asio::yield_context& yield, + std::shared_ptr ws_session, + const std::vector&, + request_type&, response_type&, + placeholders::placeholder<-7>) { + return std::make_shared(*ws_session, yield); } template @@ -228,6 +258,8 @@ struct path_handler; template struct path_handler> { Ret invoke(server_resources& resources, + asio::io_context& ioc, asio::yield_context& yield, + std::shared_ptr ws_session, Ret (*pf)(Args ...), parameter_pack& params, const std::vector& url_params, request_type& request, response_type& response) { @@ -235,8 +267,8 @@ struct path_handler> { else return static_cast, typename get_parameter::type>* - >(this)->invoke2(resources, pf, params, url_params, request, response, - get_parameter_data(resources, url_params, request, response, + >(this)->invoke2(resources, ioc, yield, ws_session, pf, params, url_params, request, response, + get_parameter_data(resources, ioc, yield, ws_session, url_params, request, response, get_parameter_value(params))); } }; @@ -251,6 +283,8 @@ struct path_handler = 0> Ret invoke2(server_resources& resources, + asio::io_context& ioc, asio::yield_context& yield, + std::shared_ptr ws_session, Ret (*pf)(Args ...), parameter_pack& params, const std::vector& url_params, request_type& request, response_type& response, @@ -261,8 +295,8 @@ struct path_handler, typename get_parameter::type, Head, Tail...>* - >(this)->invoke2(resources, pf, params, url_params, request, response, - get_parameter_data(resources, url_params, request, response, + >(this)->invoke2(resources, ioc, yield, ws_session, pf, params, url_params, request, response, + get_parameter_data(resources, ioc, yield, ws_session, url_params, request, response, get_parameter_value(params)), static_cast(head2), static_cast(tail2)...); } @@ -274,7 +308,7 @@ const std::vector> url_regex_mapping{ {std::regex{""}, R"(([A-Za-z0-9_/\.\-]+))"} }; -std::string get_re_url(const std::string& url) { +inline std::string get_re_url(const std::string& url) { std::string re_url = url; for (auto& [r, s] : url_regex_mapping) re_url = std::regex_replace(re_url, r, s); @@ -289,6 +323,8 @@ struct path_holder : std::enable_shared_from_this { std::vector&) const = 0; virtual std::optional invoke( server_resources&, + asio::io_context&, asio::yield_context&, + std::shared_ptr, const std::vector&, request_type&, response_type&) = 0; }; @@ -320,10 +356,12 @@ public: } std::optional invoke( server_resources& resources, + asio::io_context& ioc, asio::yield_context& yield, + std::shared_ptr ws_session, const std::vector& url_params, request_type& request, response_type& response) { return handler_.invoke( - resources, + resources, ioc, yield, ws_session, pf_, params_, url_params, request, response); } @@ -369,11 +407,13 @@ public: resources_ = resources; } std::optional operator()( + asio::io_context& ioc, asio::yield_context& yield, + std::shared_ptr ws_session, const std::string& url, request_type& request, response_type& response) { std::vector url_params; for (auto& ptr : paths_) { if (ptr->match(url, url_params)) - return ptr->invoke(*resources_, url_params, request, response); + return ptr->invoke(*resources_, ioc, yield, ws_session, url_params, request, response); } throw url_not_found_exception{}; } diff --git a/bserv/server.cpp b/bserv/server.cpp new file mode 100644 index 0000000..834eebb --- /dev/null +++ b/bserv/server.cpp @@ -0,0 +1,503 @@ +#include "server.hpp" + +#include "logging.hpp" +#include "utils.hpp" +#include "client.hpp" +#include "websocket.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bserv { + + +std::string get_address(const tcp::socket& socket) { + tcp::endpoint end_point = socket.remote_endpoint(); + std::string addr = end_point.address().to_string() + + ':' + std::to_string(end_point.port()); + return addr; +} + +http::response handle_request( + http::request& req, router& routes, + std::shared_ptr ws_session, + asio::io_context& ioc, asio::yield_context& yield) { + + const auto bad_request = [&req](beast::string_view why) { + http::response res{ + http::status::bad_request, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = std::string{why}; + res.prepare_payload(); + return res; + }; + + const auto not_found = [&req](beast::string_view target) { + http::response res{ + http::status::not_found, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = "The requested url '" + + std::string{target} + "' does not exist."; + res.prepare_payload(); + return res; + }; + + const auto server_error = [&req](beast::string_view what) { + http::response res{ + http::status::internal_server_error, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "text/html"); + res.keep_alive(req.keep_alive()); + res.body() = "Internal server error: " + std::string{what}; + res.prepare_payload(); + return res; + }; + + boost::string_view target = req.target(); + auto pos = target.find('?'); + boost::string_view url; + if (pos == boost::string_view::npos) url = target; + else url = target.substr(0, pos); + + http::response res{ + http::status::ok, req.version()}; + res.set(http::field::server, NAME); + res.set(http::field::content_type, "application/json"); + res.keep_alive(req.keep_alive()); + + std::optional val; + try { + val = routes(ioc, yield, ws_session, std::string{url}, req, res); + } catch (const url_not_found_exception& e) { + return not_found(url); + } catch (const bad_request_exception& e) { + return bad_request("Request body is not a valid JSON string."); + } catch (const std::exception& e) { + return server_error(e.what()); + } catch (...) { + return server_error("Unknown exception."); + } + + if (val.has_value()) { + res.body() = json::serialize(val.value()); + res.prepare_payload(); + } + + return res; +} + +class websocket_session_server; + +void handle_websocket_request( + std::shared_ptr, + std::shared_ptr session, + http::request& req, router& routes, + asio::io_context& ioc, asio::yield_context yield); + +class websocket_session_server + : public std::enable_shared_from_this { +private: + friend websocket_server; + std::string address_; + std::shared_ptr session_; + http::request req_; + router& routes_; + void on_accept(beast::error_code ec) { + if (ec) { + fail(ec, "websocket_session_server accept"); + return; + } + // handles request here + asio::spawn( + session_->ioc_, + std::bind( + &handle_websocket_request, + shared_from_this(), + session_, + std::ref(req_), + std::ref(routes_), + std::ref(session_->ioc_), + std::placeholders::_1)); + } +public: + explicit websocket_session_server( + asio::io_context& ioc, + tcp::socket&& socket, + http::request&& req, + router& routes) + : address_{get_address(socket)}, + session_{std::make_shared< + websocket_session>(address_, ioc, std::move(socket))}, + req_{std::move(req)}, routes_{routes} { + lgtrace << "websocket_session_server opened: " << address_; + } + ~websocket_session_server() { + lgtrace << "websocket_session_server closed: " << address_; + } + // starts the asynchronous accept operation + void do_accept() { + // sets suggested timeout settings for the websocket + session_->ws_.set_option( + websocket::stream_base::timeout::suggested( + beast::role_type::server)); + // sets a decorator to change the Server of the handshake + session_->ws_.set_option( + websocket::stream_base::decorator( + [](websocket::response_type& res) { + res.set( + http::field::server, + std::string{BOOST_BEAST_VERSION_STRING} + " websocket-server"); + })); + // accepts the websocket handshake + session_->ws_.async_accept( + req_, + beast::bind_front_handler( + &websocket_session_server::on_accept, + shared_from_this())); + } +}; + +void handle_websocket_request( + std::shared_ptr, + std::shared_ptr session, + http::request& req, router& routes, + asio::io_context& ioc, asio::yield_context yield) { + handle_request(req, routes, session, ioc, yield); +} + +std::string websocket_server::read() { + beast::error_code ec; + beast::flat_buffer buffer; + // reads a message into the buffer + session_.ws_.async_read(buffer, yield_[ec]); + lgtrace << "websocket_server: read from " << session_.address_; + // this indicates that the session was closed + if (ec == websocket::error::closed) { + throw websocket_closed{}; + } + if (ec) { + fail(ec, "websocket_server read"); + throw websocket_io_exception{"websocket_server read: " + ec.message()}; + } + // lgtrace << "websocket_server: received text? " << ws_.got_text() << " from " << address_; + return beast::buffers_to_string(buffer.data()); +} + +void websocket_server::write(const std::string& data) { + beast::error_code ec; + // ws_.text(ws_.got_text()); + session_.ws_.async_write(asio::buffer(data), yield_[ec]); + lgtrace << "websocket_server: write to " << session_.address_; + if (ec) { + fail(ec, "websocket_server write"); + throw websocket_io_exception{"websocket_server write: " + ec.message()}; + } +} + + +class http_session; + +// this function produces an HTTP response for the given +// request. The type of the response object depends on the +// contents of the request, so the interface requires the +// caller to pass a generic lambda for receiving the response. +// NOTE: `send` should be called only once! +template +void handle_http_request( + std::shared_ptr, + http::request req, + Send& send, router& routes, asio::io_context& ioc, asio::yield_context yield) { + send(handle_request(req, routes, nullptr, ioc, yield)); +} + +// handles an HTTP server connection +class http_session + : public std::enable_shared_from_this { +private: + // the function object is used to send an HTTP message. + class send_lambda { + private: + http_session& self_; + public: + send_lambda(http_session& self) + : self_{self} {} + template + void operator()( + http::message&& msg) const { + // the lifetime of the message has to extend + // for the duration of the async operation so + // we use a shared_ptr to manage it. + auto sp = std::make_shared< + http::message>( + std::move(msg)); + // stores a type-erased version of the shared + // pointer in the class to keep it alive. + self_.res_ = sp; + // writes the response + http::async_write( + self_.stream_, *sp, + beast::bind_front_handler( + &http_session::on_write, + self_.shared_from_this(), + sp->need_eof())); + } + } lambda_; + asio::io_context& ioc_; + beast::tcp_stream stream_; + beast::flat_buffer buffer_; + boost::optional< + http::request_parser> parser_; + std::shared_ptr res_; + router& routes_; + router& ws_routes_; + const std::string address_; + void do_read() { + // constructs a new parser for each message + parser_.emplace(); + // applies a reasonable limit to the allowed size + // of the body in bytes to prevent abuse. + parser_->body_limit(PAYLOAD_LIMIT); + // sets the timeout. + stream_.expires_after(std::chrono::seconds(EXPIRY_TIME)); + // reads a request using the parser-oriented interface + http::async_read( + stream_, buffer_, *parser_, + beast::bind_front_handler( + &http_session::on_read, + shared_from_this())); + } + void on_read( + beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + lgtrace << "received " << bytes_transferred << " byte(s) from: " << address_; + // this means they closed the connection + if (ec == http::error::end_of_stream) { + do_close(); + return; + } + if (ec) { + fail(ec, "http_session async_read"); + return; + } + + // sees if it is a websocket upgrade + if (websocket::is_upgrade(parser_->get())) { + // creates a websocket session, transferring ownership + // of both the socket and the http request + std::make_shared( + ioc_, + stream_.release_socket(), + parser_->release(), + ws_routes_ + )->do_accept(); + return; + } + + // handles the request and sends the response + + asio::spawn( + ioc_, + std::bind( + &handle_http_request, + shared_from_this(), + parser_->release(), + std::ref(lambda_), + std::ref(routes_), + std::ref(ioc_), + std::placeholders::_1)); + // handle_request(parser_->release(), lambda_, routes_); + + // at this point the parser can be reset + } + void on_write( + bool close, beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + // we're done with the response so delete it + res_.reset(); + if (ec) { + fail(ec, "http_session async_write"); + return; + } + lgtrace << "sent " << bytes_transferred << " byte(s) to: " << address_; + if (close) { + // this means we should close the connection, usually because + // the response indicated the "Connection: close" semantic. + do_close(); + return; + } + // reads another request + do_read(); + } + void do_close() { + // sends a TCP shutdown + beast::error_code ec; + stream_.socket().shutdown(tcp::socket::shutdown_send, ec); + // at this point the connection is closed gracefully + lgtrace << "socket connection closed: " << address_; + } +public: + http_session( + asio::io_context& ioc, + tcp::socket&& socket, + router& routes, + router& ws_routes) + : lambda_{*this}, + ioc_{ioc}, + stream_{std::move(socket)}, + routes_{routes}, + ws_routes_{ws_routes}, + address_{get_address(stream_.socket())} { + lgtrace << "http session opened: " << address_; + } + ~http_session() { + lgtrace << "http session closed: " << address_; + } + void run() { + asio::dispatch( + stream_.get_executor(), + beast::bind_front_handler( + &http_session::do_read, + shared_from_this())); + } +}; + +// accepts incoming connections and launches the sessions +class listener + : public std::enable_shared_from_this { +private: + asio::io_context& ioc_; + tcp::acceptor acceptor_; + router& routes_; + router& ws_routes_; + void do_accept() { + acceptor_.async_accept( + asio::make_strand(ioc_), + beast::bind_front_handler( + &listener::on_accept, + shared_from_this())); + } + void on_accept(beast::error_code ec, tcp::socket socket) { + if (ec) { + fail(ec, "listener::acceptor async_accept"); + } else { + lgtrace << "listener accepts: " << get_address(socket); + std::make_shared( + ioc_, std::move(socket), routes_, ws_routes_)->run(); + } + do_accept(); + } +public: + listener( + asio::io_context& ioc, + tcp::endpoint endpoint, + router& routes, + router& ws_routes) + : ioc_{ioc}, + acceptor_{asio::make_strand(ioc)}, + routes_{routes}, + ws_routes_{ws_routes} { + beast::error_code ec; + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + fail(ec, "listener::acceptor open"); + exit(EXIT_FAILURE); + return; + } + acceptor_.set_option( + asio::socket_base::reuse_address(true), ec); + if (ec) { + fail(ec, "listener::acceptor set_option"); + exit(EXIT_FAILURE); + return; + } + acceptor_.bind(endpoint, ec); + if (ec) { + fail(ec, "listener::acceptor bind"); + exit(EXIT_FAILURE); + return; + } + acceptor_.listen( + asio::socket_base::max_listen_connections, ec); + if (ec) { + fail(ec, "listener::acceptor listen"); + exit(EXIT_FAILURE); + return; + } + } + void run() { + asio::dispatch( + acceptor_.get_executor(), + beast::bind_front_handler( + &listener::do_accept, + shared_from_this())); + } +}; + + +server::server(const server_config& config, router&& routes, router&& ws_routes) + : ioc_{config.get_num_threads()}, + routes_{std::move(routes)}, + ws_routes_{std::move(ws_routes)} { + init_logging(config); + + // database connection + try { + db_conn_mgr_ = std::make_shared< + db_connection_manager>(config.get_db_conn_str(), config.get_num_db_conn()); + } catch (const std::exception& e) { + lgfatal << "db connection initialization failed: " << e.what() << std::endl; + exit(EXIT_FAILURE); + } + session_mgr_ = std::make_shared(); + + std::shared_ptr resources_ptr = std::make_shared(); + resources_ptr->session_mgr = session_mgr_; + resources_ptr->db_conn_mgr = db_conn_mgr_; + + routes_.set_resources(resources_ptr); + ws_routes_.set_resources(resources_ptr); + + // creates and launches a listening port + std::make_shared( + ioc_, tcp::endpoint{tcp::v4(), config.get_port()}, routes_, ws_routes_)->run(); + + // captures SIGINT and SIGTERM to perform a clean shutdown + asio::signal_set signals{ioc_, SIGINT, SIGTERM}; + signals.async_wait( + [&](const boost::system::error_code&, int) { + // stops the `io_context`. This will cause `run()` + // to return immediately, eventually destroying the + // `io_context` and all of the sockets in it. + ioc_.stop(); + }); + + lginfo << config.get_name() << " started"; + + // runs the I/O service on the requested number of threads + std::vector v; + v.reserve(config.get_num_threads() - 1); + for (int i = 1; i < config.get_num_threads(); ++i) + v.emplace_back([&]{ ioc_.run(); }); + ioc_.run(); + + // if we get here, it means we got a SIGINT or SIGTERM + lginfo << "exiting " << config.get_name(); + + // blocks until all the threads exit + for (auto & t : v) t.join(); +} + +} // bserv \ No newline at end of file diff --git a/bserv/server.hpp b/bserv/server.hpp index d4a8d98..145eb9e 100644 --- a/bserv/server.hpp +++ b/bserv/server.hpp @@ -3,8 +3,12 @@ * * reference: * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/http/server/async/http_server_async.cpp + * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/http/server/coro/http_server_coro.cpp * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/advanced/server/advanced_server.cpp * + * websocket: + * https://www.boost.org/doc/libs/1_75_0/libs/beast/example/websocket/server/async/websocket_server_async.cpp + * */ #ifndef _SERVER_HPP @@ -12,370 +16,36 @@ #include #include +#include #include -#include +#include -#include -#include -#include -#include -#include -#include #include -#include -#include #include "config.hpp" -#include "logging.hpp" -#include "utils.hpp" #include "router.hpp" #include "database.hpp" #include "session.hpp" -#include "client.hpp" namespace bserv { namespace beast = boost::beast; namespace http = beast::http; +namespace websocket = beast::websocket; namespace asio = boost::asio; namespace json = boost::json; using asio::ip::tcp; -// this function produces an HTTP response for the given -// request. The type of the response object depends on the -// contents of the request, so the interface requires the -// caller to pass a generic lambda for receiving the response. -// NOTE: `send` should be called only once! -template -void handle_request( - http::request>&& req, - Send&& send, router& routes) { - - const auto bad_request = [&req](beast::string_view why) { - http::response res{ - http::status::bad_request, req.version()}; - res.set(http::field::server, NAME); - res.set(http::field::content_type, "text/html"); - res.keep_alive(req.keep_alive()); - res.body() = std::string{why}; - res.prepare_payload(); - return res; - }; - - const auto not_found = [&req](beast::string_view target) { - http::response res{ - http::status::not_found, req.version()}; - res.set(http::field::server, NAME); - res.set(http::field::content_type, "text/html"); - res.keep_alive(req.keep_alive()); - res.body() = "The requested url '" - + std::string{target} + "' does not exist."; - res.prepare_payload(); - return res; - }; - - const auto server_error = [&req](beast::string_view what) { - http::response res{ - http::status::internal_server_error, req.version()}; - res.set(http::field::server, NAME); - res.set(http::field::content_type, "text/html"); - res.keep_alive(req.keep_alive()); - res.body() = "Internal server error: " + std::string{what}; - res.prepare_payload(); - return res; - }; - - boost::string_view target = req.target(); - auto pos = target.find('?'); - boost::string_view url; - if (pos == boost::string_view::npos) url = target; - else url = target.substr(0, pos); - - http::response res{ - http::status::ok, req.version()}; - res.set(http::field::server, NAME); - res.set(http::field::content_type, "application/json"); - res.keep_alive(req.keep_alive()); - - std::optional val; - try { - val = routes(std::string{url}, req, res); - } catch (const url_not_found_exception& e) { - send(not_found(url)); - return; - } catch (const bad_request_exception& e) { - send(bad_request("Request body is not a valid JSON string.")); - return; - } catch (const std::exception& e) { - send(server_error(e.what())); - return; - } catch (...) { - send(server_error("Unknown exception.")); - return; - } - - if (val.has_value()) { - res.body() = json::serialize(val.value()); - res.prepare_payload(); - } - - send(std::move(res)); -} - -std::string get_address(const tcp::socket& socket) { - tcp::endpoint end_point = socket.remote_endpoint(); - std::string addr = end_point.address().to_string() - + ':' + std::to_string(end_point.port()); - return addr; -} - -// handles an HTTP server connection -class http_session - : public std::enable_shared_from_this { -private: - // the function object is used to send an HTTP message. - class send_lambda { - private: - http_session& self_; - public: - send_lambda(http_session& self) - : self_{self} {} - template - void operator()( - http::message&& msg) const { - // the lifetime of the message has to extend - // for the duration of the async operation so - // we use a shared_ptr to manage it. - auto sp = std::make_shared< - http::message>( - std::move(msg)); - // stores a type-erased version of the shared - // pointer in the class to keep it alive. - self_.res_ = sp; - // writes the response - http::async_write( - self_.stream_, *sp, - beast::bind_front_handler( - &http_session::on_write, - self_.shared_from_this(), - sp->need_eof())); - } - } lambda_; - beast::tcp_stream stream_; - beast::flat_buffer buffer_; - boost::optional< - http::request_parser> parser_; - std::shared_ptr res_; - router& routes_; - const std::string address_; - void do_read() { - // constructs a new parser for each message - parser_.emplace(); - // applies a reasonable limit to the allowed size - // of the body in bytes to prevent abuse. - parser_->body_limit(PAYLOAD_LIMIT); - // sets the timeout. - stream_.expires_after(std::chrono::seconds(EXPIRY_TIME)); - // reads a request using the parser-oriented interface - http::async_read( - stream_, buffer_, *parser_, - beast::bind_front_handler( - &http_session::on_read, - shared_from_this())); - } - void on_read( - beast::error_code ec, - std::size_t bytes_transferred) { - boost::ignore_unused(bytes_transferred); - lgtrace << "received " << bytes_transferred << " byte(s) from: " << address_; - // this means they closed the connection - if (ec == http::error::end_of_stream) { - do_close(); - return; - } - if (ec) { - fail(ec, "http_session async_read"); - return; - } - // handles the request and sends the response - handle_request(parser_->release(), lambda_, routes_); - // at this point the parser can be reset - } - void on_write( - bool close, beast::error_code ec, - std::size_t bytes_transferred) { - boost::ignore_unused(bytes_transferred); - // we're done with the response so delete it - res_.reset(); - if (ec) { - fail(ec, "http_session async_write"); - return; - } - lgtrace << "sent " << bytes_transferred << " byte(s) to: " << address_; - if (close) { - // this means we should close the connection, usually because - // the response indicated the "Connection: close" semantic. - do_close(); - return; - } - // reads another request - do_read(); - } - void do_close() { - // sends a TCP shutdown - beast::error_code ec; - stream_.socket().shutdown(tcp::socket::shutdown_send, ec); - // at this point the connection is closed gracefully - lgtrace << "socket connection closed: " << address_; - } -public: - http_session(tcp::socket&& socket, router& routes) - : lambda_{*this}, stream_{std::move(socket)}, routes_{routes}, - address_{get_address(stream_.socket())} { - lgtrace << "http session opened: " << address_; - } - ~http_session() { - lgtrace << "http session closed: " << address_; - } - void run() { - asio::dispatch( - stream_.get_executor(), - beast::bind_front_handler( - &http_session::do_read, - shared_from_this())); - } -}; - -// accepts incoming connections and launches the sessions -class listener - : public std::enable_shared_from_this { -private: - asio::io_context& ioc_; - tcp::acceptor acceptor_; - router& routes_; - void do_accept() { - acceptor_.async_accept( - asio::make_strand(ioc_), - beast::bind_front_handler( - &listener::on_accept, - shared_from_this())); - } - void on_accept(beast::error_code ec, tcp::socket socket) { - if (ec) { - fail(ec, "listener::acceptor async_accept"); - } else { - lgtrace << "listener accepts: " << get_address(socket); - std::make_shared( - std::move(socket), routes_)->run(); - } - do_accept(); - } -public: - listener( - asio::io_context& ioc, - tcp::endpoint endpoint, - router& routes) - : ioc_{ioc}, - acceptor_{asio::make_strand(ioc)}, - routes_{routes} { - beast::error_code ec; - acceptor_.open(endpoint.protocol(), ec); - if (ec) { - fail(ec, "listener::acceptor open"); - exit(EXIT_FAILURE); - return; - } - acceptor_.set_option( - asio::socket_base::reuse_address(true), ec); - if (ec) { - fail(ec, "listener::acceptor set_option"); - exit(EXIT_FAILURE); - return; - } - acceptor_.bind(endpoint, ec); - if (ec) { - fail(ec, "listener::acceptor bind"); - exit(EXIT_FAILURE); - return; - } - acceptor_.listen( - asio::socket_base::max_listen_connections, ec); - if (ec) { - fail(ec, "listener::acceptor listen"); - exit(EXIT_FAILURE); - return; - } - } - void run() { - asio::dispatch( - acceptor_.get_executor(), - beast::bind_front_handler( - &listener::do_accept, - shared_from_this())); - } -}; - class server { private: // io_context for all I/O asio::io_context ioc_; router routes_; + router ws_routes_; std::shared_ptr session_mgr_; std::shared_ptr db_conn_mgr_; - std::shared_ptr http_client_ptr_; public: - server(const server_config& config, router&& routes) - : ioc_{config.get_num_threads()}, - routes_{std::move(routes)} { - init_logging(config); - - // database connection - try { - db_conn_mgr_ = std::make_shared< - db_connection_manager>(config.get_db_conn_str(), config.get_num_db_conn()); - } catch (const std::exception& e) { - lgfatal << "db connection initialization failed: " << e.what() << std::endl; - exit(EXIT_FAILURE); - } - session_mgr_ = std::make_shared(); - http_client_ptr_ = std::make_shared(ioc_); - - std::shared_ptr resources_ptr = std::make_shared(); - resources_ptr->session_mgr = session_mgr_; - resources_ptr->db_conn_mgr = db_conn_mgr_; - resources_ptr->http_client_ptr = http_client_ptr_; - - routes_.set_resources(resources_ptr); - - // creates and launches a listening port - std::make_shared( - ioc_, tcp::endpoint{tcp::v4(), config.get_port()}, routes_)->run(); - - // captures SIGINT and SIGTERM to perform a clean shutdown - asio::signal_set signals{ioc_, SIGINT, SIGTERM}; - signals.async_wait( - [&](const boost::system::error_code&, int) { - // stops the `io_context`. This will cause `run()` - // to return immediately, eventually destroying the - // `io_context` and all of the sockets in it. - ioc_.stop(); - }); - - lginfo << config.get_name() << " started"; - - // runs the I/O service on the requested number of threads - std::vector v; - v.reserve(config.get_num_threads() - 1); - for (int i = 1; i < config.get_num_threads(); ++i) - v.emplace_back([&]{ ioc_.run(); }); - ioc_.run(); - - // if we get here, it means we got a SIGINT or SIGTERM - lginfo << "exiting " << config.get_name(); - - // blocks until all the threads exit - for (auto & t : v) t.join(); - } + server(const server_config& config, router&& routes, router&& ws_routes); }; } // bserv diff --git a/bserv/session.hpp b/bserv/session.hpp index e397a14..cad8d4b 100644 --- a/bserv/session.hpp +++ b/bserv/session.hpp @@ -1,7 +1,7 @@ #ifndef _SESSION_HPP #define _SESSION_HPP -#include +#include #include #include diff --git a/bserv/utils.hpp b/bserv/utils.hpp index 97a5d2f..b4ffa6c 100644 --- a/bserv/utils.hpp +++ b/bserv/utils.hpp @@ -25,10 +25,10 @@ namespace internal { // it doesn't work with GNU GCC on Windows. // - for thread-safety, do not directly use it. // use `get_rd_value` instead. -std::random_device rd; -std::mutex rd_mutex; +inline std::random_device rd; +inline std::mutex rd_mutex; -auto get_rd_value() { +inline auto get_rd_value() { std::lock_guard lg{rd_mutex}; return rd(); } @@ -51,7 +51,7 @@ const std::string url_safe_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" } // internal // https://www.boost.org/doc/libs/1_75_0/libs/random/example/password.cpp -std::string generate_random_string(std::size_t len) { +inline std::string generate_random_string(std::size_t len) { std::string s; std::mt19937 rng{internal::get_rd_value()}; std::uniform_int_distribution<> dist{0, (int) internal::chars.length() - 1}; @@ -62,7 +62,7 @@ std::string generate_random_string(std::size_t len) { namespace security { // https://codahale.com/a-lesson-in-timing-attacks/ -bool constant_time_compare(const std::string& a, const std::string& b) { +inline bool constant_time_compare(const std::string& a, const std::string& b) { if (a.length() != b.length()) return false; int result = 0; @@ -72,7 +72,7 @@ bool constant_time_compare(const std::string& a, const std::string& b) { } // https://cryptopp.com/wiki/PKCS5_PBKDF2_HMAC -std::string hash_password( +inline std::string hash_password( const std::string& password, const std::string& salt, unsigned int iterations = 20000 /*320000*/) { @@ -91,13 +91,13 @@ std::string hash_password( return result; } -std::string encode_password(const std::string& password) { +inline std::string encode_password(const std::string& password) { std::string salt = generate_random_string(16); std::string hashed_password = hash_password(password, salt); return salt + '$' + hashed_password; } -bool check_password(const std::string& password, +inline bool check_password(const std::string& password, const std::string& encoded_password) { std::string salt, hashed_password; std::string* a = &salt, * b = &hashed_password; @@ -126,7 +126,7 @@ bool check_password(const std::string& password, // https://stackoverflow.com/questions/54060359/encoding-decoded-urls-in-c // there can be exceptions (std::stoi)! -std::string decode_url(const std::string& s) { +inline std::string decode_url(const std::string& s) { std::string r; for (std::size_t i = 0; i < s.length(); ++i) { if (s[i] == '%') { @@ -139,7 +139,7 @@ std::string decode_url(const std::string& s) { return r; } -std::string encode_url(const std::string& s) { +inline std::string encode_url(const std::string& s) { std::ostringstream oss; for (auto& c : s) { if (internal::url_safe_characters.find(c) != std::string::npos) { @@ -156,6 +156,7 @@ std::string encode_url(const std::string& s) { // where '&' can be any delimiter. // ki and vi will be converted if they are percent-encoded, // which is why the returned values are `string`, not `string_view`. +inline std::pair< std::map, std::map>> @@ -218,6 +219,7 @@ parse_params(std::string& s, std::size_t start_pos = 0, char delimiter = '&') { // this function will convert ki and vi if they are percent-encoded. // NOTE: don't misuse this function, it's going to modify // the parameter `s` in place! +inline std::tuple, std::map>> diff --git a/bserv/websocket.hpp b/bserv/websocket.hpp new file mode 100644 index 0000000..d9ca245 --- /dev/null +++ b/bserv/websocket.hpp @@ -0,0 +1,65 @@ +#ifndef _WEBSOCKET_HPP +#define _WEBSOCKET_HPP + +#include +#include +#include + +#include +#include +#include +#include + +namespace bserv { + +namespace beast = boost::beast; +namespace http = beast::http; +namespace websocket = beast::websocket; +namespace asio = boost::asio; +namespace json = boost::json; +using asio::ip::tcp; + +class websocket_closed + : public std::exception { +public: + websocket_closed() {} + const char* what() const noexcept { return "websocket session has been closed"; } +}; + +class websocket_io_exception + : public std::exception { +private: + const std::string msg_; +public: + websocket_io_exception(const std::string& msg) : msg_{msg} {} + const char* what() const noexcept { return msg_.c_str(); } +}; + +struct websocket_session { + const std::string address_; + asio::io_context& ioc_; + websocket::stream ws_; + websocket_session( + const std::string& address, + asio::io_context& ioc, + tcp::socket&& socket) + : address_{address}, + ioc_{ioc}, ws_{std::move(socket)} {} +}; + +class websocket_server { +private: + websocket_session& session_; + asio::yield_context& yield_; +public: + websocket_server(websocket_session& session, asio::yield_context& yield) + : session_{session}, yield_{yield} {} + std::string read(); + boost::json::value read_json() { return boost::json::parse(read()); } + void write(const std::string& data); + void write_json(const boost::json::value& val) { write(boost::json::serialize(val)); } +}; + +} // bserv + +#endif // _WEBSOCKET_HPP \ No newline at end of file diff --git a/handlers.hpp b/handlers.hpp index 9a21db3..380fe69 100644 --- a/handlers.hpp +++ b/handlers.hpp @@ -1,7 +1,7 @@ #ifndef _HANDLERS_HPP #define _HANDLERS_HPP -#include +#include #include #include @@ -209,11 +209,14 @@ boost::json::object user_logout( }; } -boost::json::object send_request(std::shared_ptr client_ptr) { +boost::json::object send_request( + std::shared_ptr session, + std::shared_ptr client_ptr, + boost::json::object&& params) { // post for response: // auto res = client_ptr->post( // "localhost", "8080", "/echo", {{"msg", "request"}} - // ).get(); + // ); // return {{"response", boost::json::parse(res.body())}}; // ------------------------------------------------------- // - if it takes longer than 30 seconds (by default) to @@ -221,15 +224,35 @@ boost::json::object send_request(std::shared_ptr client_ptr) // ------------------------------------------------------- // post for json response (json value, rather than json // object, is returned): - auto obj = client_ptr->post_for_object( - "localhost", "8080", "/echo", {{"msg", "request"}} - ).get(); - return {{"response", obj}}; + auto obj = client_ptr->post_for_value( + "localhost", "8080", "/echo", {{"request", params}} + ); + if (session->count("cnt") == 0) { + (*session)["cnt"] = 0; + } + (*session)["cnt"] = (*session)["cnt"].as_int64() + 1; + return {{"response", obj}, {"cnt", (*session)["cnt"]}}; } boost::json::object echo( boost::json::object&& params) { - return params; + return {{"echo", params}}; +} + +// websocket +std::nullopt_t ws_echo( + std::shared_ptr session, + std::shared_ptr ws_server) { + ws_server->write_json((*session)["cnt"]); + while (true) { + try { + std::string data = ws_server->read(); + ws_server->write(data); + } catch (bserv::websocket_closed&) { + break; + } + } + return std::nullopt; } #endif // _HANDLERS_HPP \ No newline at end of file diff --git a/main.cpp b/main.cpp index f649bab..4ed9d9a 100644 --- a/main.cpp +++ b/main.cpp @@ -120,9 +120,15 @@ int main(int argc, char* argv[]) { bserv::placeholders::db_connection_ptr, bserv::placeholders::_1), bserv::make_path("/send", &send_request, - bserv::placeholders::http_client_ptr), + bserv::placeholders::session, + bserv::placeholders::http_client_ptr, + bserv::placeholders::json_params), bserv::make_path("/echo", &echo, bserv::placeholders::json_params) + }, { + bserv::make_path("/echo", &ws_echo, + bserv::placeholders::session, + bserv::placeholders::websocket_server_ptr) }}; return EXIT_SUCCESS; diff --git a/scripts/request_test.py b/scripts/request_test.py new file mode 100644 index 0000000..a0e6174 --- /dev/null +++ b/scripts/request_test.py @@ -0,0 +1,48 @@ +import uuid + +import requests + +from multiprocessing import Process + +from pprint import pprint + +from time import time + +# session = requests.session() +# pprint(session.post("http://localhost:8080/send", json={"id": "abc"}).json()) +# pprint(session.post("http://localhost:8080/send", json={"id": "def"}).json()) +# pprint(session.post("http://localhost:8080/send", json={"id": "ghi"}).json()) +# exit() + +P = 100 # number of concurrent processes +N = 10 # for each process, the number of sessions +R = 10 # for each session, the number of posts + +def test(i): + global C + # print(f'starting process {i}') + for _ in range(N): + session = requests.session() + for i in range(1, R + 1): + session_id = str(uuid.uuid4()) + if {'cnt': i, 'response': {'echo': {'request': {'id': session_id}}}} \ + != session.post("http://localhost:8080/send", json={"id": session_id}).json(): + print('test failed!') + # print(f'exiting process {i}') + +processes = [Process(target=test, args=(i, )) for i in range(P)] + +print('starting') + +start = time() + +for p in processes: + p.start() + +for p in processes: + p.join() + +end = time() + +print('test ended') +print('elapsed: ', end - start) diff --git a/scripts/ws_test.py b/scripts/ws_test.py new file mode 100644 index 0000000..9cc440b --- /dev/null +++ b/scripts/ws_test.py @@ -0,0 +1,62 @@ +import asyncio +from multiprocessing import Process + +import requests +import websockets + +import random +import uuid +from time import time + +from pprint import pprint + +P = 500 + +def test(): + + async def fun(uri): + session = requests.session() + n = random.randint(1, 5) + # print(n) + for i in range(1, n + 1): + session_id = str(uuid.uuid4()) + ret = session.post("http://localhost:8080/send", json={"id": session_id}).json() + # print(ret) + # print({'cnt': i, 'response': {'echo': {'request': {'id': session_id}}}}) + if {'cnt': i, 'response': {'echo': {'request': {'id': session_id}}}} != ret: + print('post test failed!') + sess_id = session.cookies['bsessionid'] + # print("session id:", sess_id) + async with websockets.connect(uri, extra_headers={'Cookie': f"bsessionid={sess_id}"}) as websocket: + cnt = int(await websocket.recv()) + # print(cnt) + if cnt != n: + print('incorrect cnt') + m = random.randint(5, 10) + for _ in range(m): + session_id = str(uuid.uuid4()) + await websocket.send(session_id) + ret = await websocket.recv() + if session_id != ret: + print('ws test failed') + + asyncio.get_event_loop().run_until_complete( + fun("ws://localhost:8080/echo") + ) + +processes = [Process(target=test) for _ in range(P)] + +print('starting') + +start = time() + +for p in processes: + p.start() + +for p in processes: + p.join() + +end = time() + +print('test ended') +print('elapsed: ', end - start)