From ea0d3220db995018335c48eb06b9794235ff436b Mon Sep 17 00:00:00 2001 From: MetroWind Date: Sun, 7 Sep 2025 09:42:33 -0700 Subject: Initial commit, mostly just copied from shrt. --- src/app.cpp | 342 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 src/app.cpp (limited to 'src/app.cpp') diff --git a/src/app.cpp b/src/app.cpp new file mode 100644 index 0000000..7cfcf0d --- /dev/null +++ b/src/app.cpp @@ -0,0 +1,342 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "app.hpp" +#include "config.hpp" +#include "data.hpp" +#include "mw/error.hpp" + +namespace +{ + +std::unordered_map parseCookies(std::string_view value) +{ + std::unordered_map cookies; + size_t begin = 0; + while(true) + { + if(begin >= value.size()) + { + break; + } + + size_t semicolon = value.find(';', begin); + if(semicolon == std::string::npos) + { + semicolon = value.size(); + } + + std::string_view section = value.substr(begin, semicolon - begin); + + begin = semicolon + 1; + // Skip spaces + while(begin < value.size() && value[begin] == ' ') + { + begin++; + } + + size_t equal = section.find('='); + if(equal == std::string::npos) continue; + cookies.emplace(section.substr(0, equal), + section.substr(equal+1, semicolon - equal - 1)); + if(semicolon >= value.size()) + { + continue; + } + } + return cookies; +} + +void setTokenCookies(const mw::Tokens& tokens, App::Response& res) +{ + int64_t expire_sec = 300; + if(tokens.expiration.has_value()) + { + auto expire = std::chrono::duration_cast( + *tokens.expiration - mw::Clock::now()); + expire_sec = expire.count(); + } + res.set_header("Set-Cookie", std::format( + "shrt-access-token={}; Max-Age={}", + mw::urlEncode(tokens.access_token), expire_sec)); + // Add refresh token to cookie, with one month expiration. + if(tokens.refresh_token.has_value()) + { + expire_sec = 1800; + if(tokens.refresh_expiration.has_value()) + { + auto expire = std::chrono::duration_cast( + *tokens.refresh_expiration - mw::Clock::now()); + expire_sec = expire.count(); + } + + res.set_header("Set-Cookie", std::format( + "shrt-refresh-token={}; Max-Age={}", + mw::urlEncode(*tokens.refresh_token), expire_sec)); + } +} + +mw::HTTPServer::ListenAddress listenAddrFromConfig(const Configuration& config) +{ + if(config.listen_port == 0) + { + mw::SocketFileInfo sock(config.listen_address); + sock.user = config.socket_user; + sock.group = config.socket_group; + sock.permission = config.socket_permission; + return sock; + } + + mw::IPSocketInfo sock; + sock.address = config.listen_address; + sock.port = config.listen_port; + return sock; +} + +} // namespace + +App::App(const Configuration& conf, + std::unique_ptr data_source, + std::unique_ptr openid_auth) + : mw::HTTPServer(listenAddrFromConfig(conf)), + config(conf), + templates((std::filesystem::path(config.data_dir) / "templates" / "") + .string()), + data(std::move(data_source)), + auth(std::move(openid_auth)) +{ + auto u = mw::URL::fromStr(conf.base_url); + if(u.has_value()) + { + base_url = *std::move(u); + } + + templates.add_callback("url_for", [&](const inja::Arguments& args) -> + std::string + { + switch(args.size()) + { + case 1: + return urlFor(args.at(0)->get_ref()); + case 2: + return urlFor(args.at(0)->get_ref(), + args.at(1)->get_ref()); + default: + return "Invalid number of url_for() arguments"; + } + }); +} + +std::string App::urlFor(const std::string& name, const std::string& arg) const +{ + if(name == "statics") + { + return mw::URL(base_url).appendPath("_/statics").appendPath(arg).str(); + } + if(name == "index") + { + return base_url.str(); + } + if(name == "shortcut") + { + return mw::URL(base_url).appendPath(arg).str(); + } + if(name == "links") + { + return mw::URL(base_url).appendPath("_/links").str(); + } + if(name == "login") + { + return mw::URL(base_url).appendPath("_/login").str(); + } + if(name == "openid-redirect") + { + return mw::URL(base_url).appendPath("_/openid-redirect").str(); + } + if(name == "new-link") + { + return mw::URL(base_url).appendPath("_/new-link").str(); + } + if(name == "create-link") + { + return mw::URL(base_url).appendPath("_/create-link").str(); + } + if(name == "delete-link-dialog") + { + return mw::URL(base_url).appendPath("_/delete-link").appendPath(arg) + .str(); + } + if(name == "delete-link") + { + return mw::URL(base_url).appendPath("_/delete-link").str(); + } + + return ""; +} + +void App::handleIndex(Response& res) const +{ + res.set_redirect(urlFor("links"), 301); +} + +void App::handleLogin(Response& res) const +{ + res.set_redirect(auth->initialURL(), 301); +} + +void App::handleOpenIDRedirect(const Request& req, Response& res) const +{ + if(req.has_param("error")) + { + res.status = 500; + if(req.has_param("error_description")) + { + res.set_content( + std::format("{}: {}.", req.get_param_value("error"), + req.get_param_value("error_description")), + "text/plain"); + } + return; + } + else if(!req.has_param("code")) + { + res.status = 500; + res.set_content("No error or code in auth response", "text/plain"); + return; + } + + std::string code = req.get_param_value("code"); + spdlog::debug("OpenID server visited {} with code {}.", req.path, code); + ASSIGN_OR_RESPOND_ERROR(mw::Tokens tokens, auth->authenticate(code), res); + ASSIGN_OR_RESPOND_ERROR(mw::UserInfo user, auth->getUser(tokens), res); + + setTokenCookies(tokens, res); + res.set_redirect(urlFor("index"), 301); +} + + +std::string App::getPath(const std::string& name, + const std::string& arg_name) const +{ + return mw::URL::fromStr(urlFor(name, std::string(":") + arg_name)).value() + .path(); +} + +void App::setup() +{ + { + std::string statics_dir = (std::filesystem::path(config.data_dir) / + "statics").string(); + spdlog::info("Mounting static dir at {}...", statics_dir); + if (!server.set_mount_point("/_/statics", statics_dir)) + { + spdlog::error("Failed to mount statics"); + return; + } + } + + server.Get(getPath("index"), [&]([[maybe_unused]] const Request& req, Response& res) + { + handleIndex(res); + }); + server.Get(getPath("login"), [&]([[maybe_unused]] const Request& req, Response& res) + { + handleLogin(res); + }); + server.Get(getPath("openid-redirect"), [&](const Request& req, Response& res) + { + handleOpenIDRedirect(req, res); + }); +} + +mw::E App::validateSession(const Request& req) const +{ + if(!req.has_header("Cookie")) + { + spdlog::debug("Request has no cookie."); + return SessionValidation::invalid(); + } + + auto cookies = parseCookies(req.get_header_value("Cookie")); + if(auto it = cookies.find("shrt-access-token"); + it != std::end(cookies)) + { + spdlog::debug("Cookie has access token."); + mw::Tokens tokens; + tokens.access_token = it->second; + mw::E user = auth->getUser(tokens); + if(user.has_value()) + { + return SessionValidation::valid(*std::move(user)); + } + } + // No access token or access token expired + if(auto it = cookies.find("shrt-refresh-token"); + it != std::end(cookies)) + { + spdlog::debug("Cookie has refresh token."); + // Try to refresh the tokens. + ASSIGN_OR_RETURN(mw::Tokens tokens, auth->refreshTokens(it->second)); + ASSIGN_OR_RETURN(mw::UserInfo user, auth->getUser(tokens)); + return SessionValidation::refreshed(std::move(user), std::move(tokens)); + } + return SessionValidation::invalid(); +} + +std::optional App::prepareSession( + const Request& req, Response& res, bool allow_error_and_invalid) const +{ + mw::E session = validateSession(req); + if(!session.has_value()) + { + if(allow_error_and_invalid) + { + return SessionValidation::invalid(); + } + else + { + res.status = 500; + res.set_content("Failed to validate session.", "text/plain"); + return std::nullopt; + } + } + + switch(session->status) + { + case SessionValidation::INVALID: + if(allow_error_and_invalid) + { + return *session; + } + else + { + res.status = 401; + res.set_content("Invalid session.", "text/plain"); + return std::nullopt; + } + case SessionValidation::VALID: + break; + case SessionValidation::REFRESHED: + setTokenCookies(session->new_tokens, res); + break; + } + return *session; +} -- cgit v1.2.3-70-g09d2