Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 80 additions & 32 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
bool ignore_cout) {
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
auto signal_handler = +[](int sig) -> void {
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";;
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";
shutdown_signal = true;
};
signal(SIGINT, signal_handler);
Expand Down Expand Up @@ -288,54 +288,102 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
return false;
};

auto handle_cors = [config_service](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
const std::string& origin = req->getHeader("Origin");
CTL_INF("Origin: " << origin);

auto allowed_origins =
config_service->GetApiServerConfiguration()->allowed_origins;

auto is_contains_asterisk =
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
if (is_contains_asterisk != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
resp->addHeader("Access-Control-Allow-Methods", "*");
return;
}

// Check if the origin is in our allowed list
auto it = std::find(allowed_origins.begin(), allowed_origins.end(), origin);
if (it != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", origin);
} else if (allowed_origins.empty()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
}
resp->addHeader("Access-Control-Allow-Methods", "*");
};

drogon::app().registerPreRoutingAdvice(
[&validate_api_key](
[&validate_api_key, &handle_cors](
const drogon::HttpRequestPtr& req,
std::function<void(const drogon::HttpResponsePtr&)>&& cb,
drogon::AdviceChainCallback&& ccb) {
std::function<void(const drogon::HttpResponsePtr&)>&& stop,
drogon::AdviceChainCallback&& pass) {
// Handle OPTIONS preflight requests
if (req->method() == drogon::HttpMethod::Options) {
auto resp = HttpResponse::newHttpResponse();
auto handlers = drogon::app().getHandlersInfo();
bool has_ep = [req, &handlers]() {
for (auto const& h : handlers) {
if (req->path() == std::get<0>(h))
return true;
}
return false;
}();
if (!has_ep) {
resp->setStatusCode(drogon::HttpStatusCode::k404NotFound);
stop(resp);
return;
}

handle_cors(req, resp);
std::string supported_methods = [req, &handlers]() {
std::string methods;
for (auto const& h : handlers) {
if (req->path() == std::get<0>(h)) {
methods += drogon::to_string_view(std::get<1>(h));
methods += ", ";
}
}
if (methods.size() < 2)
return std::string();
return methods.substr(0, methods.size() - 2);
}();

// Add more info to header
resp->addHeader("Access-Control-Allow-Methods", supported_methods);
{
const auto& val = req->getHeader("Access-Control-Request-Headers");
if (!val.empty())
resp->addHeader("Access-Control-Allow-Headers", val);
}
// Set Access-Control-Max-Age
resp->addHeader("Access-Control-Max-Age",
"600"); // Cache for 10 minutes
stop(resp);
return;
}

if (!validate_api_key(req)) {
Json::Value ret;
ret["message"] = "Invalid API Key";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k401Unauthorized);
cb(resp);
stop(resp);
return;
}
ccb();
pass();
});

// CORS
drogon::app().registerPostHandlingAdvice(
[config_service](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
[config_service, &handle_cors](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
if (!config_service->GetApiServerConfiguration()->cors) {
CTL_INF("CORS is disabled!");
return;
}

const std::string& origin = req->getHeader("Origin");
CTL_INF("Origin: " << origin);

auto allowed_origins =
config_service->GetApiServerConfiguration()->allowed_origins;

auto is_contains_asterisk =
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
if (is_contains_asterisk != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
resp->addHeader("Access-Control-Allow-Methods", "*");
return;
}

// Check if the origin is in our allowed list
auto it =
std::find(allowed_origins.begin(), allowed_origins.end(), origin);
if (it != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", origin);
} else if (allowed_origins.empty()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
}
resp->addHeader("Access-Control-Allow-Methods", "*");
handle_cors(req, resp);
});

// ssl
Expand Down
Loading