"
+ "Thank you for using the AWS Advanced ODBC Wrapper! You can now close this window.
";
+
+const std::string InvalidResponse =
+ "HTTP/1.1 400 Bad Request\r\n"
+ "Content-Length: 95\r\n"
+ "Connection: close\r\n"
+ "Content-Type: text/html; charset=utf-8\r\n\r\n"
+ "The request could not be understood by the server!
";
diff --git a/driver/plugin/federated/http/Parser.cpp b/driver/plugin/federated/http/Parser.cpp
new file mode 100755
index 0000000..63195b4
--- /dev/null
+++ b/driver/plugin/federated/http/Parser.cpp
@@ -0,0 +1,174 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "../../../util/logger_wrapper.h"
+#include "Parser.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+void Parser::ParseRequestLine(std::string& str)
+{
+ std::istringstream istr(str);
+ std::vector command(
+ (std::istream_iterator(istr)),
+ std::istream_iterator()
+ );
+
+ if ((command.size() == 3) && (command[0].find(METHOD) != std::string::npos) &&
+ (command[2].find(HTTP_VERSION) != std::string::npos)) {
+ parser_state_ = STATE::PARSE_HEADER;
+ } else if ((command.size() == 3) &&
+ (command[0].find(GET_METHOD) != std::string::npos) &&
+ command[2].find(HTTP_VERSION) != std::string::npos) {
+ const size_t question_mark_pos = command[1].find('?');
+ if (question_mark_pos != std::string::npos) {
+ std::string parsed_body = command[1].substr(question_mark_pos + 1);
+ ParseBodyLine(parsed_body);
+ } else {
+ parser_state_ = STATE::PARSE_GET_REQUEST;
+ }
+ } else {
+ throw std::runtime_error("Request line contains wrong information.");
+ }
+}
+
+void Parser::ParseHeaderLine(std::string& str)
+{
+ if (str == "\r") {
+ parser_state_ = STATE::PARSE_BODY;
+
+ return;
+ }
+
+ const size_t ind = str.find(':');
+ header_size_ += str.size();
+
+ if ((ind == std::string::npos) || (header_size_ > MAX_HEADER_SIZE)) {
+ throw std::runtime_error("Received invalid header line.");
+ }
+
+ str.erase(std::remove_if(str.begin() + static_cast(ind), str.end(), ::isspace), str.end());
+
+ header_.insert({
+ std::string(str.begin(), str.begin() + static_cast(ind)),
+ std::string(str.begin() + static_cast(ind) + 1, str.end())
+ });
+}
+
+void Parser::ParseBodyLine(std::string& str)
+{
+ auto str_begin = str.begin();
+ auto str_end = str.end();
+
+ while (true) {
+ auto equal_it = find(str_begin, str_end, '=');
+ auto ampersand_it = find(str_begin, str_end, '&');
+
+ if (equal_it == str_end) {
+ throw std::runtime_error("Received invalid body line.");
+ }
+
+ body_.insert({
+ std::string(str_begin, equal_it),
+ std::string(equal_it + 1, ampersand_it)
+ });
+
+ if ((equal_it == str_end) || (ampersand_it == str_end)) {
+ break;
+ }
+
+ str_begin = ampersand_it + 1;
+ }
+
+ parser_state_ = STATE::PARSE_FINISHED;
+}
+
+void Parser::ParsePostRequest(std::string& str)
+{
+ switch (parser_state_) {
+ case STATE::PARSE_REQUEST:
+ ParseRequestLine(str);
+ break;
+ case STATE::PARSE_HEADER:
+ ParseHeaderLine(str);
+ break;
+ case STATE::PARSE_BODY:
+ if (!header_.contains("Content-Type") || !header_.contains("Content-Length") ||
+ (header_["Content-Type"] != "application/x-www-form-urlencoded") ||
+ (header_["Content-Length"] != std::to_string(str.size())))
+ {
+ throw std::runtime_error("Can't start parsing body as header contains invalid information.");
+ }
+
+ ParseBodyLine(str);
+ break;
+ default:
+ break;
+ }
+}
+
+STATUS Parser::Parse(std::istream &in)
+{
+ std::string str;
+ bool is_line_received = false;
+
+ while (getline(in, str) && parser_state_ != STATE::PARSE_GET_REQUEST) {
+ is_line_received = true;
+
+ try {
+ ParsePostRequest(str);
+ } catch (std::exception& e) {
+ DLOG(INFO) << e.what();
+ break;
+ }
+ }
+
+ if (parser_state_ == STATE::PARSE_GET_REQUEST) {
+ parser_state_ = STATE::PARSE_REQUEST;
+ return STATUS::GET_SUCCESS;
+ }
+
+ if (parser_state_ != STATE::PARSE_FINISHED) {
+ return is_line_received ? STATUS::FAILED : STATUS::EMPTY_REQUEST;
+ }
+
+ return STATUS::SUCCEED;
+}
+
+Parser::Parser()
+ : parser_state_(STATE::PARSE_REQUEST)
+ , header_size_(0)
+{
+ ; // Do nothing.
+}
+
+bool Parser::IsFinished() const
+{
+ return parser_state_ == STATE::PARSE_FINISHED;
+}
+
+std::string Parser::RetrieveAuthCode(std::string& state)
+{
+ if (body_.contains("code"))
+ {
+ return body_["code"];
+ }
+
+ return "";
+}
diff --git a/driver/plugin/federated/http/Parser.h b/driver/plugin/federated/http/Parser.h
new file mode 100755
index 0000000..6b13934
--- /dev/null
+++ b/driver/plugin/federated/http/Parser.h
@@ -0,0 +1,114 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+enum class STATE
+{
+ PARSE_REQUEST,
+ PARSE_HEADER,
+ PARSE_BODY,
+ PARSE_FINISHED,
+ PARSE_GET_REQUEST
+};
+
+enum class STATUS
+{
+ SUCCEED,
+ FAILED,
+ EMPTY_REQUEST,
+ GET_SUCCESS
+};
+
+/*
+* This class is used to parse the HTTP POST request
+* and retrieve authorization code.
+*/
+class Parser {
+ public:
+
+ Parser();
+
+ ~Parser() = default;
+
+ /*
+ * Initiate request parsing.
+ *
+ * Return true if parse was successful, otherwise false.
+ */
+ STATUS Parse(std::istream &in);
+
+ /*
+ * Check if parser is finished to parse the POST request.
+ *
+ * Return true if parsing was successfully finished, otherwise false.
+ */
+ bool IsFinished() const;
+
+ /*
+ * Retrieve authorization code.
+ *
+ * Return received authorization code or empty string.
+ */
+ std::string RetrieveAuthCode(std::string& state);
+
+ private:
+ /*
+ * Parse received POST request line by line.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParsePostRequest(std::string& str);
+
+ /*
+ * Parse request-line and perform verification for:
+ * method, Request-URI and HTTP-Version.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseRequestLine(std::string& str);
+
+ /*
+ * Parse request header line.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseHeaderLine(std::string& str);
+
+ /*
+ * Parse request body line in application/x-www-form-urlencoded format.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseBodyLine(std::string& str);
+
+ /*
+ * Expected METHOD, URI and HTTP VERSION in request line.
+ */
+ const std::string METHOD = "POST";
+ // const std::string URI = "/redshift/";
+ const std::string HTTP_VERSION = "HTTP/1.1";
+ const int MAX_HEADER_SIZE = 8192;
+
+ const std::string GET_METHOD = "GET";
+ const std::string PKCE_URI = "/?code=";
+
+ STATE parser_state_;
+ size_t header_size_;
+ std::unordered_map header_;
+ std::unordered_map body_;
+
+};
diff --git a/driver/plugin/federated/http/Selector.cpp b/driver/plugin/federated/http/Selector.cpp
new file mode 100644
index 0000000..5dd7f79
--- /dev/null
+++ b/driver/plugin/federated/http/Selector.cpp
@@ -0,0 +1,56 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "Selector.h"
+
+void Selector::Register(SOCKET sfd)
+{
+ if (sfd == -1) {
+ return;
+ }
+
+ FD_SET(sfd, &master_fds_);
+
+#ifdef WIN32
+ max_fd_ = max(max_fd_, sfd);
+#else
+ max_fd_ = std::max(max_fd_, sfd);
+#endif
+}
+
+void Selector::Unregister(SOCKET sfd)
+{
+ FD_CLR(sfd, &master_fds_);
+}
+
+bool Selector::Select(struct timeval *tv)
+{
+ // As select will modify the file descriptor set we should keep temporary set to reflect the ready fd.
+ fd_set read_fds = master_fds_;
+
+ if (select(max_fd_ + 1, &read_fds, nullptr, nullptr, tv) > 0) {
+ for (SOCKET sfd = 0; sfd <= max_fd_; sfd++) {
+ if (FD_ISSET(sfd, &read_fds)) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+Selector::Selector() : max_fd_(0)
+{
+ FD_ZERO(&master_fds_);
+}
diff --git a/driver/plugin/federated/http/Selector.h b/driver/plugin/federated/http/Selector.h
new file mode 100755
index 0000000..d48a7de
--- /dev/null
+++ b/driver/plugin/federated/http/Selector.h
@@ -0,0 +1,52 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "SocketSupport.h"
+
+#include
+
+/*
+* This class is used to support insertion/deletion operations on
+* master descriptor set and selecting incoming connections.
+*/
+class Selector
+{
+ public:
+ Selector();
+
+ ~Selector() = default;
+
+ /*
+ * Update master descriptor set with new socket.
+ */
+ void Register(SOCKET sfd);
+
+ /*
+ * Remove socket from master descriptor set.
+ */
+ void Unregister(SOCKET sfd);
+
+ /*
+ * If any incoming connections is readable return true, or
+ * false otherwise.
+ */
+ bool Select(struct timeval* tv);
+
+ private:
+
+ SOCKET max_fd_;
+ fd_set master_fds_;
+};
diff --git a/driver/plugin/federated/http/Socket.cpp b/driver/plugin/federated/http/Socket.cpp
new file mode 100755
index 0000000..328f7f9
--- /dev/null
+++ b/driver/plugin/federated/http/Socket.cpp
@@ -0,0 +1,214 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "AddrInformation.h"
+#include "Socket.h"
+
+#include
+#include
+#include
+
+int Socket::Receive(char *buffer, int length, int flags) const
+{
+ int nbytes = 0;
+ int filled = 0;
+ auto start = std::chrono::system_clock
+ ::now();
+
+ const int receive_wait = 200;
+ // Give a chance to fully receive packet in case if there is no data in
+ // non-blocking socket.
+ while ((std::chrono::system_clock::now() - start < std::chrono::seconds(1))) {
+ nbytes = static_cast(recv(socket_fd_, buffer + filled, length - filled - 1, 0));
+
+ if (nbytes <= 0) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(receive_wait));
+ } else {
+ filled += nbytes;
+ }
+ }
+
+ return filled == 0 ? nbytes : filled;
+}
+
+int Socket::Send(const char *buffer, int length, int flags) const
+ {
+ int nbytes = 0;
+ int sent = 0;
+
+ while ((nbytes = static_cast(send(socket_fd_, buffer + sent, length - sent, flags))) > 0) {
+ sent += nbytes;
+ }
+
+ return sent == 0 ? nbytes : sent;
+}
+
+void Socket::PrepareListenSocket(const std::string& port)
+{
+ const AddrInformation addr_info(port);
+
+ for (const auto& ptr : addr_info) {
+ socket_fd_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+
+ if (socket_fd_ == INVALID_SOCKET) {
+ continue;
+ }
+
+ if (!SetReusable()) {
+ Close();
+
+ throw std::runtime_error("Unable to reuse the address.");
+ }
+
+ if (Bind(ptr->ai_addr, ptr->ai_addrlen) == 0) {
+ break;
+ }
+
+ Close();
+ }
+
+ if (!SetNonBlocking()) {
+ Close();
+
+ throw std::runtime_error("Unable to set socket to non-blocking mode.");
+ }
+
+ if ((socket_fd_ == INVALID_SOCKET) || (Listen(CONNECTION_BACKLOG))) {
+ Close();
+
+ throw std::runtime_error("Can not start listening on port: " + port);
+ }
+}
+
+bool Socket::Close()
+{
+ if (socket_fd_ == -1) {
+ return false;
+ }
+
+#if (defined(_WIN32) || defined(_WIN64))
+ int res = closesocket(socket_fd_);
+#else
+ const int res = close(socket_fd_);
+#endif
+
+ socket_fd_ = INVALID_SOCKET;
+
+ return res == 0;
+}
+
+Socket::Socket() : socket_fd_(INVALID_SOCKET)
+{
+ ; // Do nothing.
+}
+
+Socket::Socket(SOCKET sfd) : socket_fd_(sfd)
+{
+ ; // Do nothing.
+}
+
+Socket::Socket(Socket&& s)
+{
+ socket_fd_ = std::move(s.socket_fd_);
+
+ s.socket_fd_ = INVALID_SOCKET;
+}
+
+Socket& Socket::operator=(Socket&& s)
+{
+ socket_fd_ = std::move(s.socket_fd_);
+
+ s.socket_fd_ = INVALID_SOCKET;
+
+ return *this;
+}
+
+Socket::~Socket()
+{
+ Close();
+}
+
+bool Socket::SetNonBlocking() const
+{
+#if (defined(_WIN32) || defined(_WIN64))
+ unsigned long mode = 1;
+ return ioctlsocket(socket_fd_, FIONBIO, &mode) == 0 ? true : false;
+#else
+ int flags = fcntl(socket_fd_, F_GETFL, 0);
+ return fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK) == 0 ? true : false;
+#endif
+}
+
+bool Socket::IsNonBlockingError() const
+{
+#if (defined(_WIN32) || defined(_WIN64))
+ return WSAGetLastError() == WSAEWOULDBLOCK;
+#else
+ return errno == EWOULDBLOCK;
+#endif
+}
+
+void Socket::Register(Selector& selector) const
+{
+ selector.Register(socket_fd_);
+}
+
+void Socket::Unregister(Selector& selector) const
+{
+ selector.Unregister(socket_fd_);
+}
+
+int Socket::Bind(const struct sockaddr *address, size_t address_len) const
+{
+ return bind(socket_fd_, address, (int)address_len);
+}
+
+int Socket::GetListenPort() const
+{
+ sockaddr_storage addr;
+ socklen_t len = sizeof(addr);
+ int port = 0;
+
+ getsockname(socket_fd_, (struct sockaddr*)&addr, &len);
+ port = htons(((sockaddr_in*)&addr)->sin_port);
+
+ return port;
+}
+
+int Socket::Listen(int backlog) const
+{
+ return listen(socket_fd_, backlog);
+}
+
+Socket Socket::Accept() const
+{
+ sockaddr_storage remoteaddr;
+ socklen_t addrlen = sizeof(remoteaddr);
+
+ return Socket(accept(socket_fd_, (struct sockaddr*)&remoteaddr, &addrlen));
+}
+
+bool Socket::SetReusable() const
+{
+ int yes = 1;
+
+ // Windows: int setsockopt(SOCKET s, int level, int optname, const char *optval, int optlen);
+ // The optval parameter has ponter to const char, but according to the MSDN we should use int:
+ // To enable a Boolean option, the optval parameter points to a nonzero integer.
+ // To disable the option optval points to an integer equal to zero.
+ // The optlen parameter should be equal to sizeof(int) for Boolean options.
+ // On Linux the optval parameter is pointer to void.
+ return setsockopt(socket_fd_, SOL_SOCKET, SO_REUSEADDR,
+ (char *)&yes, sizeof(yes)) == 0;
+}
diff --git a/driver/plugin/federated/http/Socket.h b/driver/plugin/federated/http/Socket.h
new file mode 100755
index 0000000..2d4e465
--- /dev/null
+++ b/driver/plugin/federated/http/Socket.h
@@ -0,0 +1,126 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOCKET_H_
+#define SOCKET_H_
+
+#pragma once
+
+#include "SocketSupport.h"
+#include "Selector.h"
+#include
+
+/*
+* This class is used to wrap the network socket
+* in order to have cross-platfrom code to send and receive
+* data from incoming connections.
+*/
+class Socket
+{
+ public:
+ Socket();
+
+ Socket( SOCKET sfd);
+
+ Socket(const Socket& s) = delete;
+
+ Socket& operator=(const Socket& s) = delete;
+
+ Socket(Socket&& s);
+
+ Socket& operator=(Socket&& s);
+
+ ~Socket();
+
+ /*
+ * Close a socket file descriptor, return true on success
+ * or false in case of error.
+ */
+ bool Close();
+
+ /*
+ * Get port where socket is listening.
+ */
+ int GetListenPort() const;
+
+ /*
+ * Listen for connections on a socket and return zero on success,
+ * or -1 in case of error.
+ */
+ int Listen(int backlog) const;
+
+ /*
+ * Accept a connection on a socket and return StreamSocket object.
+ */
+ Socket Accept() const;
+
+ /*
+ * Forcibly bind to a port in use by another socket and return true on success,
+ * or false in case of error.
+ */
+ bool SetReusable() const;
+
+ /*
+ * Set socket to non-blocking mode and return true on success
+ * or false in case of error.
+ */
+ bool SetNonBlocking() const;
+
+ /*
+ * Return true if error is caused by non-blocking mode of socket
+ * otherwise return false.
+ */
+ bool IsNonBlockingError() const;
+
+ /*
+ * Prepare socket to handle incoming connections.
+ */
+ void PrepareListenSocket(const std::string& port);
+
+ /*
+ * Register socket in master file descriptor set using Selector class.
+ */
+ void Register(Selector& selector) const;
+
+ /*
+ * Clear socket in master file descriptor set using Selector class.
+ */
+ void Unregister(Selector& selector) const;
+
+ /*
+ * Receive a message from a socket and return the number of bytes received,
+ * or -1 if an error occurred.
+ */
+ int Receive(char *buffer, int length, int flags) const;
+
+ /*
+ * Send a message on a socket and return the number of bytes sent,
+ * or -1 if an error occured.
+ */
+ int Send(const char *buffer, int length, int flags) const;
+
+ /*
+ * Bind a name to a socket and return zero on success,
+ * or -1 in case of error.
+ */
+ int Bind(const struct sockaddr *address, size_t address_len) const;
+
+ private:
+
+ const int CONNECTION_BACKLOG = 10;
+
+ SOCKET socket_fd_;
+};
+
+#endif // SOCKET_H_
diff --git a/driver/plugin/federated/http/SocketStream.cpp b/driver/plugin/federated/http/SocketStream.cpp
new file mode 100644
index 0000000..1d41a3c
--- /dev/null
+++ b/driver/plugin/federated/http/SocketStream.cpp
@@ -0,0 +1,50 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "SocketStream.h"
+
+SocketStream::SocketStream(Socket& s) : received_size_(0), socket_(s)
+{
+ setg(input_buffer_, input_buffer_, input_buffer_);
+}
+
+SocketStream::~SocketStream()
+{
+ sync();
+}
+
+int SocketStream::underflow()
+{
+ if (gptr() < egptr()) {
+ return traits_type::to_int_type(*gptr());
+ }
+
+ int received_bytes = socket_.Receive(input_buffer_, SIZE - 1, 0);
+
+ /*
+ * Return EOF in the following cases:
+ * If the received bytes less than zero (error situation);
+ * If the received bytes equal to zero (socket peer has performed an orderly shutdown);
+ * If the length of received packets more than MAX_SIZE.
+ */
+ if ((received_bytes <= 0) || (received_size_ + received_bytes > MAX_SIZE)) {
+ return traits_type::eof();
+ }
+
+ received_size_ += received_bytes;
+
+ setg(input_buffer_, input_buffer_, input_buffer_ + received_bytes);
+
+ return traits_type::to_int_type(*gptr());
+}
\ No newline at end of file
diff --git a/driver/plugin/federated/http/SocketStream.h b/driver/plugin/federated/http/SocketStream.h
new file mode 100755
index 0000000..5d1eb7b
--- /dev/null
+++ b/driver/plugin/federated/http/SocketStream.h
@@ -0,0 +1,59 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOCKET_STREAM_H_
+#define SOCKET_STREAM_H_
+
+#pragma once
+
+#include "Socket.h"
+
+#include
+
+/*
+* This class is used to wrap socket by std::stream.
+*/
+class SocketStream : public std::streambuf
+{
+ public:
+ /*
+ * Construct object and set the value for the pointers that define
+ * the boundaries of the buffered portion of the controlled INPUT sequence.
+ */
+ SocketStream(Socket& s);
+
+ /*
+ * Call sync inside the destructor to synchronize the contents in the stream buffer
+ * with those of the associated character sequence.
+ */
+ virtual ~SocketStream();
+
+ protected:
+ /*
+ * Virtual function called by other member functions to get the current character
+ * in the controlled input sequence without changing the current position.
+ * It is called by public member functions such as sgetc to request
+ * a new character when there are no read positions available at the get pointer (gptr).
+ */
+ virtual int underflow();
+
+ static const int SIZE = 2048;
+ static const int MAX_SIZE = 16384;
+
+ int received_size_;
+ char input_buffer_[SIZE];
+ Socket& socket_;
+};
+
+#endif // SOCKET_STREAM_H_
diff --git a/driver/plugin/federated/http/SocketSupport.h b/driver/plugin/federated/http/SocketSupport.h
new file mode 100644
index 0000000..074c0d2
--- /dev/null
+++ b/driver/plugin/federated/http/SocketSupport.h
@@ -0,0 +1,35 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if (defined(_WIN32) || defined(_WIN64))
+
+#include
+#include
+#include "Ws2tcpip.h"
+#pragma comment(lib, "Ws2_32.lib")
+
+#else /* Linux or MAC */
+
+#include
+#include
+#include
+#include
+#include
+
+typedef int SOCKET;
+#define INVALID_SOCKET -1
+
+#endif
diff --git a/driver/plugin/federated/http/WEBServer.cpp b/driver/plugin/federated/http/WEBServer.cpp
new file mode 100755
index 0000000..1b84cf5
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer.cpp
@@ -0,0 +1,153 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "../../../util/logger_wrapper.h"
+#include "HtmlResponse.h"
+#include "SocketStream.h"
+#include "WEBServer.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+WEBServer::WEBServer( std::string& state,
+ std::string& port, std::string& timeout) :
+ state_(state),
+ port_(port),
+ timeout_(std::stoi(timeout)),
+ selector_(),
+ parser_(),
+ listen_socket_(),
+ listen_port_(0),
+ connections_counter_(0),
+ listening_(false)
+{
+ ; // Do nothing.
+}
+
+void WEBServer::LaunchServer()
+{
+ // Create thread to launch the server to listen the incoming connection.
+ // Waiting for the redirect response from the /oauth2/authorize.
+ thread_ = std::thread(&WEBServer::ListenerThread, this);
+}
+
+void WEBServer::Join()
+{
+ if (thread_.joinable()) {
+ thread_.join();
+ }
+}
+
+std::string WEBServer::GetCode() const
+{
+ return code_;
+}
+
+int WEBServer::GetListenPort() const
+{
+ return listen_port_;
+}
+
+bool WEBServer::IsListening() const
+{
+ return listening_.load();
+}
+
+bool WEBServer::IsTimeout() const
+{
+ return connections_counter_ > 0 ? false : true;
+}
+
+void WEBServer::Cancel() {
+ cancel_ = true;
+}
+
+bool WEBServer::WEBServerInit()
+{
+ // Prepare the environment for get the socket description.
+ try {
+ listen_socket_.PrepareListenSocket(port_);
+ listen_socket_.Register(selector_);
+ listen_port_ = listen_socket_.GetListenPort();
+ } catch (std::exception& e) {
+ DLOG(INFO) << "Exception: " << e.what();
+
+ return false;
+ }
+
+ return true;
+}
+
+void WEBServer::HandleConnection()
+{
+ /* Trying to accept the pending incoming connection. */
+ Socket ssck(listen_socket_.Accept());
+
+ ++connections_counter_;
+
+ if (ssck.SetNonBlocking()) {
+ SocketStream socket_buffer(ssck);
+ std::istream socket_stream(&socket_buffer);
+
+ const STATUS status = parser_.Parse(socket_stream);
+
+ if (status == STATUS::SUCCEED) {
+ ssck.Send(ValidResponse.c_str(), static_cast(ValidResponse.size()), 0);
+ } else if (status == STATUS::FAILED) {
+ ssck.Send(InvalidResponse.c_str(), static_cast(InvalidResponse.size()), 0);
+ } else if (status == STATUS::GET_SUCCESS) {
+ const std::string response = std::vformat(GetResponse, std::make_format_args(listen_port_));
+ ssck.Send(response.c_str(), static_cast(response.size()), 0);
+ } else {
+ /* Nothing is received from socket. Continue to listen. */
+ return;
+ }
+ }
+}
+
+void WEBServer::Listen()
+{
+ // Set timeout for non-blocking socket to 1 sec to pass it to Select.
+ struct timeval tv = { .tv_sec=1, .tv_usec=0 };
+
+ if (selector_.Select(&tv)) {
+ HandleConnection();
+ }
+}
+
+void WEBServer::ListenerThread()
+{
+ if (!WEBServerInit()) {
+ DLOG(INFO) << "WEBServerInit Failed";
+ return;
+ }
+
+ auto start = std::chrono::system_clock::now();
+
+ listening_.store(true);
+
+ while ((std::chrono::system_clock::now() - start < std::chrono::seconds(timeout_)) && !parser_.IsFinished()) {
+ Listen();
+ }
+
+ code_ = parser_.RetrieveAuthCode(state_);
+
+ listen_socket_.Close();
+}
diff --git a/driver/plugin/federated/http/WEBServer.h b/driver/plugin/federated/http/WEBServer.h
new file mode 100755
index 0000000..4be1205
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer.h
@@ -0,0 +1,118 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef WEBSERVER_H_
+#define WEBSERVER_H_
+
+#pragma once
+
+#include "Parser.h"
+#include "Selector.h"
+#include "Socket.h"
+
+#include
+#include
+
+/*
+* This class is used to launch the HTTP WEB server in separate thread
+* to wait for the redirect response from the /oauth2/authorize and
+* extract the authorization code from it.
+*/
+class WEBServer
+{
+ public:
+ WEBServer(
+ std::string& state,
+ std::string& port,
+ std::string& timeout);
+
+ ~WEBServer() = default;
+
+ /*
+ * Launch the HTTP WEB server in separate thread.
+ */
+ void LaunchServer();
+
+ /*
+ * Wait until HTTP WEB server is finished.
+ */
+ void Join();
+
+ /*
+ * Extract the authorization code from response.
+ */
+ std::string GetCode() const;
+
+ /*
+ * Get port where server is listening.
+ */
+ int GetListenPort() const;
+
+ /*
+ * Extract the SAML Assertion from response.
+ */
+ std::string GetSamlResponse() const;
+
+ /*
+ * If server is listening for connections return true, otherwise return false.
+ */
+ bool IsListening() const;
+
+ /*
+ * If timeout happened return true, otherwise return false.
+ */
+ bool IsTimeout() const;
+
+ /*
+ * Cancel listen loop prematurely without waiting for the timeout.
+ */
+ void Cancel();
+
+ private:
+ /*
+ * Main HTTP WEB server function that perform initialization and
+ * listen for incoming connections for specified time by user.
+ */
+ void ListenerThread();
+
+ /*
+ * If incoming connection is available call HandleConnection.
+ */
+ void Listen();
+
+ /*
+ * Launch parser if incoming connection is acceptable.
+ */
+ void HandleConnection();
+
+ /*
+ * Perform socket preparation to launch the HTTP WEB server.
+ */
+ bool WEBServerInit();
+
+ std::string state_;
+ std::string port_;
+ int timeout_;
+ std::string code_;
+ std::thread thread_;
+ Selector selector_;
+ Parser parser_;
+ Socket listen_socket_;
+ int listen_port_;
+ int connections_counter_;
+ std::atomic listening_;
+ bool cancel_ = false;
+};
+
+#endif // WEBSERVER_H_
diff --git a/driver/plugin/federated/http/WEBServer_utils.cpp b/driver/plugin/federated/http/WEBServer_utils.cpp
new file mode 100644
index 0000000..49c4f2a
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer_utils.cpp
@@ -0,0 +1,40 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "WEBServer_utils.h"
+
+int WebServerUtils::GenerateRandomInteger(int low, int high)
+{
+ std::random_device rd;
+ std::mt19937 generator(rd());
+ std::uniform_int_distribution<> dist(low, high);
+
+ return dist(generator);
+}
+
+std::string WebServerUtils::GenerateState()
+{
+ const char STATE_CHAR_LIST[37] = "0123456789abcdefghijklmnopqrstuvwxyz";
+ const int chars_size = (sizeof(STATE_CHAR_LIST) / sizeof(*STATE_CHAR_LIST)) - 1;
+ const int rand_size = GenerateRandomInteger(9, chars_size - 1);
+ std::string state;
+
+ state.reserve(rand_size);
+
+ for (int i = 0; i < rand_size; ++i) {
+ state.push_back(STATE_CHAR_LIST[GenerateRandomInteger(0, rand_size)]);
+ }
+
+ return state;
+}
diff --git a/driver/plugin/federated/http/WEBServer_utils.h b/driver/plugin/federated/http/WEBServer_utils.h
new file mode 100644
index 0000000..d159738
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer_utils.h
@@ -0,0 +1,25 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef WEBSERVER_UTILS_H_
+#define WEBSERVER_UTILS_H_
+
+#include
+
+namespace WebServerUtils {
+ int GenerateRandomInteger(int low, int high);
+ std::string GenerateState();
+}
+
+#endif // WEBSERVER_UTILS_H_
diff --git a/driver/plugin/federated/okta_auth_plugin.cpp b/driver/plugin/federated/okta_auth_plugin.cpp
index 95c66ce..23028ee 100644
--- a/driver/plugin/federated/okta_auth_plugin.cpp
+++ b/driver/plugin/federated/okta_auth_plugin.cpp
@@ -12,17 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "okta_auth_plugin.h"
+#if (defined(_WIN32) || defined(_WIN64))
+ #include
+ #include
+ #undef GetObject
+#endif
-#include
+#include "okta_auth_plugin.h"
#include "html_util.h"
+#include "http/WEBServer.h"
+#include "http/WEBServer_utils.h"
+#include "saml_util.h"
-#include "../../util/aws_sdk_helper.h"
#include "../../util/connection_string_keys.h"
#include "../../util/logger_wrapper.h"
+#include "../../util/rds_strings.h"
#include "../../util/rds_utils.h"
-#include "saml_util.h"
OktaAuthPlugin::OktaAuthPlugin(DBC *dbc) : OktaAuthPlugin(dbc, nullptr) {}
@@ -113,7 +119,7 @@ SQLRETURN OktaAuthPlugin::Connect(
ret = next_plugin->Connect(ConnectionHandle, WindowHandle, OutConnectionString, BufferLength, StringLengthPtr, DriverCompletion);
// Unsuccessful connection using cached token
- // Skip cache and generate a new token to retry
+ // Skip cache and generate a new token to retry
if (!SQL_SUCCEEDED(ret) && token.second) {
LOG(WARNING) << "Cached token failed to connect. Retrying with fresh token";
// Update AWS Credentials
@@ -145,6 +151,16 @@ OktaSamlUtil::OktaSamlUtil(
}
sign_in_url = "https://" + idp_endpoint + ":" + idp_port + "/app/amazon_aws/" + app_id + "/sso/saml" + "?onetimetoken=";
session_token_url = "https://" + idp_endpoint + ":" + idp_port + "/api/v1/authn";
+
+ const std::string mfa_type_str = connection_attributes.contains(KEY_MFA_TYPE) ?
+ connection_attributes.at(KEY_MFA_TYPE) : "";
+ if (mfa_type_table.contains(mfa_type_str)) {
+ mfa_type = mfa_type_table.at(mfa_type_str);
+ }
+ mfa_port = connection_attributes.contains(KEY_MFA_PORT) ?
+ connection_attributes.at(KEY_MFA_PORT) : DEFAULT_PORT;
+ mfa_timeout = connection_attributes.contains(KEY_MFA_TIMEOUT) ?
+ connection_attributes.at(KEY_MFA_TIMEOUT) : DEFAULT_MFA_TIMEOUT;
}
std::string OktaSamlUtil::GetSamlAssertion()
@@ -214,10 +230,169 @@ std::string OktaSamlUtil::GetSessionToken()
LOG(ERROR) << "Unable to parse JSON from response";
return "";
}
+
+ if (mfa_type != NONE) {
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+
+ if (!json_view.KeyExists("stateToken")) {
+ LOG(ERROR) << "Could not find state token in JSON";
+ return "";
+ }
+
+ const std::string state_token = json_view.GetString("stateToken");
+ const Aws::Utils::Json::JsonView embedded_view = json_view.GetObject("_embedded");
+ Aws::Utils::Array factor_views = embedded_view.GetArray("factors");
+
+ std::string factor_id;
+ for (int i = 0; i < factor_views.GetLength(); i++) {
+ const std::string type = factor_views[i].GetString("factorType");
+ if (mfa_type == TOTP && type == "token:software:totp" || mfa_type == PUSH && type == "push") {
+ factor_id = factor_views[i].GetString("id");
+ }
+ }
+
+ if (factor_id.empty()) {
+ LOG(ERROR) << "Could not find factor in JSON";
+ return "";
+ }
+
+ const std::string verify_url = session_token_url + "/factors/" + factor_id + "/verify";
+ if (mfa_type == TOTP) {
+ return VerifyTOTPChallenge(verify_url, state_token);
+ }
+ if (mfa_type == PUSH) {
+ return VerifyPushChallenge(verify_url, state_token);
+ }
+ }
+
const Aws::Utils::Json::JsonView json_view = json_val.View();
if (!json_view.KeyExists("sessionToken")) {
LOG(ERROR) << "Could not find session token in JSON";
return "";
}
+
return json_view.GetString("sessionToken");
}
+
+std::string OktaSamlUtil::VerifyTOTPChallenge(
+ const std::string &verify_url,
+ const std::string &state_token)
+{
+ std::string state = WebServerUtils::GenerateState();
+ WEBServer srv(state, mfa_port, mfa_timeout);
+
+ srv.LaunchServer();
+
+ const std::string mfa_form_url = WEBSERVER_HOST + ":" + mfa_port;
+ try {
+#if (defined(_WIN32) || defined(_WIN64))
+ const HINSTANCE result = ShellExecute(NULL, RDS_TSTR(std::string("open")).c_str(), RDS_TSTR(mfa_form_url).c_str(), NULL, NULL, SW_SHOWNORMAL);
+ if (reinterpret_cast(result) <= 32) {
+ srv.Cancel();
+ }
+#else
+#if (defined(LINUX) || defined(__linux__))
+ const int result = system(("xdg-open " + mfa_form_url).c_str());
+#else
+ const int result = system(("open " + mfa_form_url).c_str());
+#endif
+ if (result != 0) {
+ srv.Cancel();
+ }
+#endif
+ } catch (const std::exception & e) {
+ srv.Cancel();
+ srv.Join();
+ LOG(ERROR) << "Could not open browser to obtain MFA token: " << e.what();
+ return "";
+ }
+
+ srv.Join();
+
+ const std::string pass_code = srv.GetCode();
+ if (pass_code.empty()) {
+ LOG(ERROR) << "MFA Authorization code was not obtained";
+ return "";
+ }
+
+ const std::shared_ptr req = Aws::Http::CreateHttpRequest(
+ verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
+ Aws::Utils::Json::JsonValue json_body;
+ json_body
+ .WithString("stateToken", state_token)
+ .WithString("passCode", pass_code);
+ const Aws::String json_str = json_body.View().WriteReadable();
+ const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size());
+ req->SetContentType("application/json");
+ req->AddContentBody(Aws::MakeShared("", json_str));
+ req->SetContentLength(json_len);
+ const std::shared_ptr response = http_client->MakeRequest(req);
+
+ // Check resp status
+ if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) {
+ LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode();
+ if (response->HasClientError()) {
+ LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage();
+ }
+ return "";
+ }
+
+ const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody());
+ if (!json_val.WasParseSuccessful()) {
+ LOG(ERROR) << "Unable to parse JSON from response";
+ return "";
+ }
+
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+ if (!json_view.KeyExists("sessionToken")) {
+ LOG(ERROR) << "Could not find session token in JSON";
+ return "";
+ }
+
+ return json_view.GetString("sessionToken");
+}
+
+std::string OktaSamlUtil::VerifyPushChallenge(
+ const std::string &verify_url,
+ const std::string &state_token)
+{
+ const std::chrono::time_point end_time = std::chrono::system_clock::now() + std::chrono::seconds(std::strtol(mfa_timeout.c_str(), nullptr, 0));
+ while (std::chrono::system_clock::now() < end_time) {
+ const std::shared_ptr req = Aws::Http::CreateHttpRequest(
+ verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
+ Aws::Utils::Json::JsonValue json_body;
+ json_body.WithString("stateToken", state_token);
+ const Aws::String json_str = json_body.View().WriteReadable();
+ const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size());
+ req->SetContentType("application/json");
+ req->AddContentBody(Aws::MakeShared("", json_str));
+ req->SetContentLength(json_len);
+ const std::shared_ptr response = http_client->MakeRequest(req);
+
+ // Check resp status
+ if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) {
+ LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode();
+ if (response->HasClientError()) {
+ LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage();
+ }
+ } else {
+ const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody());
+ if (!json_val.WasParseSuccessful()) {
+ LOG(ERROR) << "Unable to parse JSON from response";
+ continue;
+ }
+
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+ if (!json_view.KeyExists("sessionToken")) {
+ LOG(ERROR) << "Could not find session token in JSON";
+ continue;
+ }
+
+ return json_view.GetString("sessionToken");
+ }
+
+ std::this_thread::sleep_for(std::chrono::seconds(VERIFY_PUSH_INTERVAL));
+ }
+ LOG(ERROR) << "The MFA challenge was not completed in time";
+ return "";
+}
diff --git a/driver/plugin/federated/okta_auth_plugin.h b/driver/plugin/federated/okta_auth_plugin.h
index d0239a4..8f4d44c 100644
--- a/driver/plugin/federated/okta_auth_plugin.h
+++ b/driver/plugin/federated/okta_auth_plugin.h
@@ -21,6 +21,18 @@
#include "../base_plugin.h"
#include "../../driver.h"
+typedef enum {
+ NONE,
+ TOTP,
+ PUSH
+} MfaType;
+
+static std::map const mfa_type_table = {
+ {"", MfaType::NONE},
+ {VALUE_MFA_TOTP, MfaType::TOTP},
+ {VALUE_MFA_PUSH, MfaType::PUSH}
+};
+
class OktaSamlUtil : public SamlUtil {
public:
OktaSamlUtil(const std::map &connection_attributes);
@@ -30,8 +42,19 @@ class OktaSamlUtil : public SamlUtil {
private:
std::string GetSessionToken();
+ std::string VerifyTOTPChallenge(const std::string &verify_url, const std::string &state_token);
+
+ std::string VerifyPushChallenge(const std::string &verify_url, const std::string &state_token);
+
+ static inline const std::string DEFAULT_MFA_TIMEOUT = "60";
+ static inline const int VERIFY_PUSH_INTERVAL = 5;
+ static inline const std::string DEFAULT_PORT = "8080";
+ static inline const std::string WEBSERVER_HOST = "http://127.0.0.1";
std::string sign_in_url;
std::string session_token_url;
+ MfaType mfa_type;
+ std::string mfa_port;
+ std::string mfa_timeout;
static inline const std::regex SAML_RESPONSE_PATTERN = std::regex("");
};
diff --git a/driver/util/attribute_validator.cpp b/driver/util/attribute_validator.cpp
index 4b68941..275fa55 100644
--- a/driver/util/attribute_validator.cpp
+++ b/driver/util/attribute_validator.cpp
@@ -34,7 +34,9 @@ bool AttributeValidator::ShouldKeyBeUnsignedInt(const std::string& key) {
KEY_FAILOVER_TIMEOUT,
KEY_LIMITLESS_MONITOR_INTERVAL_MS,
KEY_ROUTER_MAX_RETRIES,
- KEY_LIMITLESS_MAX_RETRIES
+ KEY_LIMITLESS_MAX_RETRIES,
+ KEY_MFA_PORT,
+ KEY_MFA_TIMEOUT
};
return INTEGER_KEYS.contains(key);
}
diff --git a/driver/util/connection_string_helper.h b/driver/util/connection_string_helper.h
index f97676a..a607a0d 100644
--- a/driver/util/connection_string_helper.h
+++ b/driver/util/connection_string_helper.h
@@ -48,6 +48,9 @@ static std::unordered_set const aws_odbc_key_set = {
KEY_HTTP_CONNECT_TIMEOUT,
KEY_RELAY_PARTY_ID,
KEY_APP_ID,
+ KEY_MFA_TYPE,
+ KEY_MFA_PORT,
+ KEY_MFA_TIMEOUT,
KEY_DATABASE_DIALECT,
KEY_HOST_SELECTOR_STRATEGY,
KEY_ENABLE_FAILOVER,
diff --git a/driver/util/connection_string_keys.h b/driver/util/connection_string_keys.h
index 435d570..f716eb5 100644
--- a/driver/util/connection_string_keys.h
+++ b/driver/util/connection_string_keys.h
@@ -75,6 +75,11 @@
#define KEY_RELAY_PARTY_ID "RELAY_PARTY_ID"
/* OKTA */
#define KEY_APP_ID "APP_ID"
+#define KEY_MFA_TYPE "MFA_TYPE"
+#define KEY_MFA_PORT "MFA_PORT"
+#define KEY_MFA_TIMEOUT "MFA_TIMEOUT"
+#define VALUE_MFA_TOTP "TOTP"
+#define VALUE_MFA_PUSH "PUSH"
/* Host Selectors */
#define KEY_HOST_SELECTOR_STRATEGY "HOST_SELECTOR_STRATEGY"