diff --git a/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md b/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md
index 1cda4ed..7e57c6f 100644
--- a/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md
+++ b/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md
@@ -14,29 +14,33 @@ When a user wants access to a resource, it authenticates with the IdP. From this
1. Follow steps in [Enable AWS IAM Database Authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.Enabling.html) to setup IAM authentication.
2. Configure Okta as the AWS identity provider following [Okta's official documentation](https://help.okta.com/en-us/content/topics/deploymentguides/aws/aws-deployment.htm)
+3. (Optional) Enable MFA. MFA through Okta Verify is supported for the Push and OTP methods. Please ensure the authentication policies and/or global session policies have been configured to use MFA.
### Connection String / DSN Configuration for Okta Authentication Plugin Support
-| Field | Connection Option Key | Value | Default Value | Sample Value |
-|-----------------------|------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|--------------------------------------------------------|
-| Authentication Type | `RDS_AUTH_TYPE` | Must be `OKTA`. | `database` | `OKTA` |
-| Server | `SERVER` | Database instance server host. | nil | `database.us-east-1-rds.amazon.com` |
-| Port | `PORT` | Port that the database is listening on. | nil | 5432 |
-| User Name | `UID` | Database user name for IAM authentication. | nil | `iam_user` |
-| IAM Host | `IAM_HOST` | The endpoint used to generate the authentication token. This is only required if you are connecting using custom endpoints such as an IP address. | nil | `database.us-east-1-rds.amazon.com` |
-| Region | `REGION` | The region of the database for IAM authentication. | `us-east-1` | `us-east-1` |
-| Database | `DATABASE` | Default database that a user will work on. | nil | `my_database` |
-| Token Expiration | `TOKEN_EXPIRATION` | Token expiration in seconds, supported max value is 900. | 900 | 900 |
-| IdP Endpoint | `IDP_ENDPOINT` | The ADFS host that is used to authenticate with. | nil | `my-adfs-host.com` |
-| IdP Port | `IDP_PORT` | The ADFS host port. | 443 | 443 |
-| IdP User Name | `IDP_USERNAME` | The user name for the IdP Endpoint server. | nil | `user@email.com` |
-| IdP Password | `IDP_PASSWORD` | The IdP user's password. | nil | `my_password_123` |
-| Role ARN | `IDP_ROLE_ARN` | The ARN of the IAM Role that is to be assumed for database access. | nil | `arn:aws:iam::123412341234:role/ADFS-SAML-Assume` |
-| IdP SAML Provider ARN | `IDP_SAML_ARN` | The ARN of the Identity Provider. | nil | `arn:aws:iam::123412341234:saml-provider/ADFS-AWS-IAM` |
-| HTTP Socket Timeout | `HTTP_SOCKET_TIMEOUT` | The socket timeout value in milliseconds for the HttpClient reading. | 3000 | 3000 |
-| HTTP Connect Timeout | `HTTP_CONNECT_TIMEOUT` | The connect timeout value in milliseconds for the HttpClient. | 5000 | 5000 |
-| App ID | `APP_ID` | The application ID for AWS configured on. | nil | `my-app-id` |
-| Extra URL Encode | `EXTRA_URL_ENCODE` | Generated tokens can have URL encoding prefix duplication for scenarios where underlying drivers automatically decode the URL before passing to the database for connections. | `0` | `1` |
+| Field | Connection Option Key | Value | Default Value | Sample Value |
+|-----------------------|------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|--------------------------------------------------------|
+| Authentication Type | `RDS_AUTH_TYPE` | Must be `OKTA`. | `database` | `OKTA` |
+| Server | `SERVER` | Database instance server host. | nil | `database.us-east-1-rds.amazon.com` |
+| Port | `PORT` | Port that the database is listening on. | nil | `5432` |
+| User Name | `UID` | Database user name for IAM authentication. | nil | `iam_user` |
+| IAM Host | `IAM_HOST` | The endpoint used to generate the authentication token. This is only required if you are connecting using custom endpoints such as an IP address. | nil | `database.us-east-1-rds.amazon.com` |
+| Region | `REGION` | The region of the database for IAM authentication. | `us-east-1` | `us-east-1` |
+| Database | `DATABASE` | Default database that a user will work on. | nil | `my_database` |
+| Token Expiration | `TOKEN_EXPIRATION` | Token expiration in seconds, supported max value is 900. | `900` | `900` |
+| IdP Endpoint | `IDP_ENDPOINT` | The ADFS host that is used to authenticate with. | nil | `my-adfs-host.com` |
+| IdP Port | `IDP_PORT` | The ADFS host port. | `443` | `443` |
+| IdP User Name | `IDP_USERNAME` | The user name for the IdP Endpoint server. | nil | `user@email.com` |
+| IdP Password | `IDP_PASSWORD` | The IdP user's password. | nil | `my_password_123` |
+| Role ARN | `IDP_ROLE_ARN` | The ARN of the IAM Role that is to be assumed for database access. | nil | `arn:aws:iam::123412341234:role/ADFS-SAML-Assume` |
+| IdP SAML Provider ARN | `IDP_SAML_ARN` | The ARN of the Identity Provider. | nil | `arn:aws:iam::123412341234:saml-provider/ADFS-AWS-IAM` |
+| HTTP Socket Timeout | `HTTP_SOCKET_TIMEOUT` | The socket timeout value in milliseconds for the HttpClient reading. | `3000` | `3000` |
+| HTTP Connect Timeout | `HTTP_CONNECT_TIMEOUT` | The connect timeout value in milliseconds for the HttpClient. | `5000` | `5000` |
+| App ID | `APP_ID` | The application ID for AWS configured on. | nil | `my-app-id` |
+| Extra URL Encode | `EXTRA_URL_ENCODE` | Generated tokens can have URL encoding prefix duplication for scenarios where underlying drivers automatically decode the URL before passing to the database for connections. | `0` | `1` |
+| MFA Type | `MFA_TYPE` | The MFA type the user specifies. The available options are: `TOTP`, `PUSH`. **Note**: the `TOTP` type requires a web browser to be used. | nil | `TOTP` |
+| MFA Port | `MFA_PORT` | The port used to connect to `127.0.0.1` to provide the one time code when using TOTP as the MFA Type. | `8080` | `8000` |
+| MFA Timeout | `MFA_TIMEOUT` | The time in seconds to complete the MFA challenge before the connection fails. | `60` | `30` |
> [!WARNING]\
> Using IAM Authentication, connections to the database must have SSL enabled. Please refer to the underlying driver's specifications to enable this.
diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt
index 07e6bbd..b478067 100644
--- a/driver/CMakeLists.txt
+++ b/driver/CMakeLists.txt
@@ -74,6 +74,17 @@ set(INC
${CMAKE_CURRENT_SOURCE_DIR}/host_info.h
${CMAKE_CURRENT_SOURCE_DIR}/odbcapi.h
${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_rds_helper.h
+
+ # Webserver
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/AddrInformation.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/HtmlResponse.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Parser.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Selector.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Socket.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketStream.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketSupport.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer_utils.h
)
set(SRC
@@ -116,6 +127,15 @@ set(SRC
${CMAKE_CURRENT_SOURCE_DIR}/host_info.cpp
${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_rds_helper.cpp
+
+ # Webserver
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/AddrInformation.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Parser.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Selector.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Socket.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketStream.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer_utils.cpp
)
# GUI
diff --git a/driver/gui/odbcsetup.rc b/driver/gui/odbcsetup.rc
index 637c65d..6581987 100644
--- a/driver/gui/odbcsetup.rc
+++ b/driver/gui/odbcsetup.rc
@@ -148,43 +148,49 @@ STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD
FONT 8, "MS Shell Dlg", 400, 0, 0x1
BEGIN
COMBOBOX IDC_AUTH_MODE,62,7,96,30,CBS_DROPDOWN | CBS_SORT | WS_VSCROLL | WS_TABSTOP
- EDITTEXT IDC_REGION,62,24,96,14,ES_AUTOHSCROLL
- EDITTEXT IDC_EXPIRE,62,44,96,14,ES_AUTOHSCROLL
- EDITTEXT IDC_IAM_HOST,62,82,96,14,ES_AUTOHSCROLL | WS_DISABLED
- EDITTEXT IDC_IAM_PORT,62,101,96,14,ES_AUTOHSCROLL | WS_DISABLED
- EDITTEXT IDC_SECRET,62,120,96,14,ES_AUTOHSCROLL | WS_DISABLED
- EDITTEXT IDC_SECRET_REGION,62,139,96,14,ES_AUTOHSCROLL | WS_DISABLED
- EDITTEXT IDC_SECRET_END,62,159,96,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_REGION,62,22,96,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_EXPIRE,62,39,96,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_IAM_HOST,62,56,96,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_IAM_PORT,62,73,96,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_SECRET,62,90,96,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_SECRET_REGION,62,107,96,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_SECRET_END,62,124,96,14,ES_AUTOHSCROLL | WS_DISABLED
EDITTEXT IDC_IDP_UID,209,7,106,14,ES_AUTOHSCROLL
- EDITTEXT IDC_IDP_PWD,209,25,106,14,ES_PASSWORD | ES_AUTOHSCROLL
- EDITTEXT IDC_IDP_END,209,45,106,14,ES_AUTOHSCROLL
- EDITTEXT IDC_APP_ID,209,65,106,14,ES_AUTOHSCROLL | WS_DISABLED
- EDITTEXT IDC_ROLE_ARN,209,85,106,14,ES_AUTOHSCROLL
- EDITTEXT IDC_IDP_ARN,209,104,106,14,ES_AUTOHSCROLL
- EDITTEXT IDC_IDP_PORT,209,124,106,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_IDP_PWD,209,24,106,14,ES_PASSWORD | ES_AUTOHSCROLL
+ EDITTEXT IDC_IDP_END,209,42,106,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_APP_ID,209,146,106,14,ES_AUTOHSCROLL | WS_DISABLED
+ EDITTEXT IDC_ROLE_ARN,209,59,106,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_IDP_ARN,209,77,106,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_IDP_PORT,209,94,106,14,ES_AUTOHSCROLL
RTEXT "Authentication Mode:",IDC_STATIC,10,5,47,18,0,WS_EX_RIGHT
- RTEXT "IAM Host:",IDC_STATIC,23,84,33,8,0,WS_EX_RIGHT
- RTEXT "IAM Port:",IDC_STATIC,25,103,32,8,0,WS_EX_RIGHT
- RTEXT "Expire Time (ms):",IDC_STATIC,17,42,40,16,0,WS_EX_RIGHT
- LTEXT "Secret ID:",IDC_STATIC,23,122,34,8,0,WS_EX_RIGHT
- RTEXT "Secret Endpoint:",IDC_STATIC,25,156,32,17,0,WS_EX_RIGHT
- LTEXT "Secret Region:",IDC_STATIC,6,140,51,8,0,WS_EX_RIGHT
+ RTEXT "IAM Host:",IDC_STATIC,23,58,33,8,0,WS_EX_RIGHT
+ RTEXT "IAM Port:",IDC_STATIC,25,75,32,8,0,WS_EX_RIGHT
+ RTEXT "Expire Time (ms):",IDC_STATIC,17,37,40,16,0,WS_EX_RIGHT
+ LTEXT "Secret ID:",IDC_STATIC,23,92,34,8,0,WS_EX_RIGHT
+ RTEXT "Secret Endpoint:",IDC_STATIC,25,121,32,17,0,WS_EX_RIGHT
+ LTEXT "Secret Region:",IDC_STATIC,6,108,51,8,0,WS_EX_RIGHT
RTEXT "IDP Username:",IDC_STATIC,166,5,39,17,0,WS_EX_RIGHT
- RTEXT "IDP Password:",IDC_STATIC,168,23,37,18,0,WS_EX_RIGHT
- RTEXT "IDP Endpoint:",IDC_STATIC,169,44,36,16,0,WS_EX_RIGHT
- LTEXT "App ID:",IDC_STATIC,179,67,26,8,0,WS_EX_RIGHT
- RTEXT "IAM Role ARN:",IDC_STATIC,171,83,34,16,0,WS_EX_RIGHT
- RTEXT "IDP SAML ARN:",IDC_STATIC,173,103,32,15,0,WS_EX_RIGHT
- RTEXT "IDP Port:",IDC_STATIC,175,126,30,8,0,WS_EX_RIGHT
- CONTROL "",IDC_URL_ENCODE,"Button",BS_AUTOCHECKBOX | BS_LEFTTEXT | BS_MULTILINE | WS_TABSTOP,62,62,10,17
- LTEXT "Extra URL Encoding:",IDC_STATIC,21,61,36,18,0,WS_EX_RIGHT
- LTEXT "AWS Region:",IDC_STATIC,13,27,44,8,0,WS_EX_RIGHT
- EDITTEXT IDC_RELAY_PARTY_ID,209,144,106,14,ES_AUTOHSCROLL
- RTEXT "Relaying Party ID:",IDC_STATIC,172,142,33,16
- EDITTEXT IDC_CONNECT_TIMEOUT,209,164,106,14,ES_AUTOHSCROLL
- EDITTEXT IDC_SOCKET_TIMEOUT,62,178,96,14,ES_AUTOHSCROLL
- RTEXT "Connect Timeout:",IDC_STATIC,163,163,42,17
- RTEXT "Socket Timeout:",IDC_STATIC,18,177,39,17
+ RTEXT "IDP Password:",IDC_STATIC,168,22,37,18,0,WS_EX_RIGHT
+ RTEXT "IDP Endpoint:",IDC_STATIC,169,41,36,16,0,WS_EX_RIGHT
+ LTEXT "App ID:",IDC_STATIC,179,148,26,8,0,WS_EX_RIGHT
+ RTEXT "IAM Role ARN:",IDC_STATIC,171,57,34,16,0,WS_EX_RIGHT
+ RTEXT "IDP SAML ARN:",IDC_STATIC,173,76,32,15,0,WS_EX_RIGHT
+ RTEXT "IDP Port:",IDC_STATIC,175,96,30,8,0,WS_EX_RIGHT
+ CONTROL "",IDC_URL_ENCODE,"Button",BS_AUTOCHECKBOX | BS_LEFTTEXT | BS_MULTILINE | WS_TABSTOP,61,178,10,17
+ LTEXT "Extra URL Encoding:",IDC_STATIC,21,177,36,18,0,WS_EX_RIGHT
+ LTEXT "AWS Region:",IDC_STATIC,13,25,44,8,0,WS_EX_RIGHT
+ EDITTEXT IDC_RELAY_PARTY_ID,209,112,106,14,ES_AUTOHSCROLL
+ RTEXT "Relaying Party ID:",IDC_STATIC,172,110,33,16
+ EDITTEXT IDC_CONNECT_TIMEOUT,209,129,106,14,ES_AUTOHSCROLL
+ EDITTEXT IDC_SOCKET_TIMEOUT,62,141,96,14,ES_AUTOHSCROLL
+ RTEXT "Connect Timeout:",IDC_STATIC,163,128,42,17
+ RTEXT "Socket Timeout:",IDC_STATIC,18,140,39,17
+ COMBOBOX IDC_MFA_TYPE,62,161,96,30,CBS_DROPDOWN | CBS_SORT | WS_VSCROLL | WS_TABSTOP
+ LTEXT "MFA Type:",IDC_STATIC,32,159,25,19,0,WS_EX_RIGHT
+ EDITTEXT IDC_MFA_PORT,209,163,106,14,ES_AUTOHSCROLL
+ LTEXT "MFA Server Port:",IDC_STATIC,164,161,41,17,0,WS_EX_RIGHT
+ LTEXT "MFA Timeout (s):",IDC_STATIC,158,179,47,16,0,WS_EX_RIGHT
+ EDITTEXT IDC_MFA_TIMEOUT,209,181,106,14,ES_AUTOHSCROLL
END
IDC_TAB_FAILOVER DIALOGEX 5, 15, 322, 201
@@ -260,7 +266,7 @@ BEGIN
VERTGUIDE, 205
VERTGUIDE, 209
TOPMARGIN, 7
- BOTTOMMARGIN, 194
+ BOTTOMMARGIN, 195
END
IDC_TAB_FAILOVER, DIALOG
diff --git a/driver/gui/resource.h b/driver/gui/resource.h
index 99aa7db..14ee904 100644
--- a/driver/gui/resource.h
+++ b/driver/gui/resource.h
@@ -75,6 +75,10 @@
#define IDC_CONNECT_TIMEOUT 1059
#define IDC_SOCKET_TIMEOUT 1061
#define IDC_DB_DIALECT 1062
+#define IDC_MFA_PORT 1063
+#define IDC_MFA_TIMEOUT 1064
+#define IDC_MFA_TYPE 1065
+
// Next default values for new objects
//
@@ -82,7 +86,7 @@
#ifndef APSTUDIO_READONLY_SYMBOLS
#define _APS_NEXT_RESOURCE_VALUE 126
#define _APS_NEXT_COMMAND_VALUE 40001
-#define _APS_NEXT_CONTROL_VALUE 1063
+#define _APS_NEXT_CONTROL_VALUE 1066
#define _APS_NEXT_SYMED_VALUE 101
#endif
#endif
diff --git a/driver/gui/setup.cpp b/driver/gui/setup.cpp
index e29b561..efd9aa3 100644
--- a/driver/gui/setup.cpp
+++ b/driver/gui/setup.cpp
@@ -107,7 +107,10 @@ const std::map> FED_AUTH_KEYS = {
{KEY_IDP_PORT, {IDC_IDP_PORT, EDIT_TEXT}},
{KEY_RELAY_PARTY_ID, {IDC_RELAY_PARTY_ID, EDIT_TEXT}},
{KEY_HTTP_CONNECT_TIMEOUT, {IDC_CONNECT_TIMEOUT, EDIT_TEXT}},
- {KEY_HTTP_SOCKET_TIMEOUT, {IDC_SOCKET_TIMEOUT, EDIT_TEXT}}
+ {KEY_HTTP_SOCKET_TIMEOUT, {IDC_SOCKET_TIMEOUT, EDIT_TEXT}},
+ {KEY_MFA_TYPE, {IDC_MFA_TYPE, COMBO}},
+ {KEY_MFA_PORT, {IDC_MFA_PORT, EDIT_TEXT}},
+ {KEY_MFA_TIMEOUT, {IDC_MFA_TIMEOUT, EDIT_TEXT}}
};
const std::map> FAILOVER_KEYS = {
@@ -164,6 +167,12 @@ const std::vector> DB_DIALECTS = {
{"Aurora PostgreSQL Limitless", VALUE_DB_DIALECT_AURORA_POSTGRESQL_LIMITLESS}
};
+const std::vector> MFA_TYPES = {
+ {"None", ""},
+ {"TOTP", VALUE_MFA_TOTP},
+ {"Push", VALUE_MFA_PUSH}
+};
+
HINSTANCE ghInstance;
HWND tab_control;
HWND aws_auth_tab;
@@ -252,6 +261,8 @@ std::string GetControlValue(HWND hwnd, std::pair pair)
return LIMITLESS_MODES[selection].second;
case IDC_DB_DIALECT:
return DB_DIALECTS[selection].second;
+ case IDC_MFA_TYPE:
+ return MFA_TYPES[selection].second;
default:
break;
}
@@ -441,7 +452,6 @@ void TestConnection(HWND hwnd)
RDS_AllocDbc(henv, &hdbc);
std::string test_conn_str = GetDsn(true);
-
SQLRETURN ret = RDS_SQLDriverConnect(
hdbc,
nullptr,
@@ -712,7 +722,10 @@ void HandleAuthModeSelection(HWND hwnd) {
if (!IAM_KEYS.contains(keys.first) &&
!FED_AUTH_KEYS.contains(keys.first) &&
!AUTH_KEYS.contains(keys.first) ||
- id == IDC_APP_ID ) {
+ id == IDC_APP_ID ||
+ id == IDC_MFA_TYPE ||
+ id == IDC_MFA_PORT ||
+ id == IDC_MFA_TIMEOUT) {
show_ctrl = false;
}
break;
@@ -761,7 +774,13 @@ BOOL AuthTabInit(HWND hwnd, HWND hwndFocus, LPARAM lParam)
ComboBox_InsertString(auth_mode_box, i, RDS_TSTR(AWS_AUTH_MODES[i].first).c_str());
}
+ HWND mfa_type_box = GetDlgItem(hwnd, IDC_MFA_TYPE);
+ for (int i = 0; i < MFA_TYPES.size(); i++) {
+ ComboBox_InsertString(mfa_type_box, i, RDS_TSTR(MFA_TYPES[i].first).c_str());
+ }
+
SetInitialComboBoxValue(hwnd, IDC_AUTH_MODE, KEY_AUTH_TYPE, AWS_AUTH_MODES);
+ SetInitialComboBoxValue(hwnd, IDC_MFA_TYPE, KEY_MFA_TYPE, MFA_TYPES);
HandleAuthModeSelection(hwnd);
diff --git a/driver/plugin/federated/http/AddrInformation.cpp b/driver/plugin/federated/http/AddrInformation.cpp
new file mode 100755
index 0000000..59bcf6b
--- /dev/null
+++ b/driver/plugin/federated/http/AddrInformation.cpp
@@ -0,0 +1,120 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "AddrInformation.h"
+
+#include
+#include
+
+AddrInformation::AddrInformation( const std::string& port)
+{
+ addrinfo hints;
+
+ memset(&hints, 0, sizeof(hints));
+
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_PASSIVE;
+
+ if (getaddrinfo("127.0.0.1", port.c_str(), &hints, &addrinfo_) != 0) {
+ throw std::runtime_error("Unable to get address information.");
+ }
+}
+
+AddrInformation::~AddrInformation()
+{
+ freeaddrinfo(addrinfo_);
+}
+
+AddrInformationIterator AddrInformation::begin()
+{
+ return AddrInformationIterator(addrinfo_);
+}
+
+AddrInformationIterator AddrInformation::end()
+{
+ return AddrInformationIterator(nullptr);
+}
+
+AddrInformationIterator AddrInformation::begin() const
+{
+ return AddrInformationIterator(addrinfo_);
+}
+
+AddrInformationIterator AddrInformation::end() const
+{
+ return AddrInformationIterator(nullptr);
+}
+
+AddrInformationIterator::AddrInformationIterator(addrinfo* addr) : addr_(addr)
+{
+ ; // Do nothing.
+}
+
+AddrInformationIterator::AddrInformationIterator(const AddrInformationIterator& itr)
+ : addr_(itr.addr_)
+{
+ ; // Do nothing.
+}
+
+AddrInformationIterator::~AddrInformationIterator()
+{
+ ; // Do nothing.
+}
+
+AddrInformationIterator AddrInformationIterator::operator++(int)
+{
+ AddrInformationIterator tmp(*this);
+
+ if (addr_)
+ {
+ addr_ = addr_->ai_next;
+ }
+
+ return tmp;
+}
+
+AddrInformationIterator& AddrInformationIterator::operator++()
+{
+ if (addr_) {
+ addr_ = addr_->ai_next;
+ }
+
+ return *this;
+}
+
+bool AddrInformationIterator::operator!=(const AddrInformationIterator& itr) const
+{
+ return addr_ != itr.addr_;
+}
+
+bool AddrInformationIterator::operator==(const AddrInformationIterator& itr) const
+{
+ return addr_ == itr.addr_;
+}
+
+addrinfo* AddrInformationIterator::operator->()
+{
+ return addr_;
+}
+
+const addrinfo* AddrInformationIterator::operator*() const
+{
+ return addr_;
+}
+
+addrinfo* AddrInformationIterator::operator*()
+{
+ return addr_;
+}
diff --git a/driver/plugin/federated/http/AddrInformation.h b/driver/plugin/federated/http/AddrInformation.h
new file mode 100755
index 0000000..f20de13
--- /dev/null
+++ b/driver/plugin/federated/http/AddrInformation.h
@@ -0,0 +1,73 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "SocketSupport.h"
+#include
+
+/*
+* This class is used to support range-based for loop and iterators
+* in AddrInformation class.
+*/
+class AddrInformationIterator {
+public:
+ AddrInformationIterator(addrinfo *addr);
+
+ AddrInformationIterator(const AddrInformationIterator& itr);
+
+ ~AddrInformationIterator();
+
+ AddrInformationIterator& operator++();
+
+ AddrInformationIterator operator++(int);
+
+ bool operator!=(const AddrInformationIterator& itr) const;
+
+ bool operator==(const AddrInformationIterator& itr) const;
+
+ addrinfo* operator->();
+
+ const addrinfo* operator*() const;
+
+ addrinfo* operator*();
+
+private:
+ addrinfo *addr_;
+
+};
+
+/*
+* This class is used to perform network address and service
+* translation via getaddrinfo call. Returns one or more addrinfo structures, each
+* of which contains an Internet address.
+*/
+class AddrInformation
+{
+public:
+ AddrInformation( const std::string& port);
+
+ ~AddrInformation();
+
+ AddrInformationIterator begin();
+
+ AddrInformationIterator end();
+
+ AddrInformationIterator begin() const;
+
+ AddrInformationIterator end() const;
+
+private:
+ addrinfo *addrinfo_;
+};
diff --git a/driver/plugin/federated/http/HtmlResponse.h b/driver/plugin/federated/http/HtmlResponse.h
new file mode 100755
index 0000000..bcfd15c
--- /dev/null
+++ b/driver/plugin/federated/http/HtmlResponse.h
@@ -0,0 +1,44 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+inline std::string GetResponse =
+ "HTTP/1.1 200 OK\r\n"
+ "Content-Length: 300\r\n"
+ "Connection: close\r\n"
+ "Content-Type: text/html; charset=utf-8\r\n\r\n"
+ "
"
+ "Thank you for using the AWS Advanced ODBC Wrapper! You can now close this window.
";
+
+const std::string InvalidResponse =
+ "HTTP/1.1 400 Bad Request\r\n"
+ "Content-Length: 95\r\n"
+ "Connection: close\r\n"
+ "Content-Type: text/html; charset=utf-8\r\n\r\n"
+ "The request could not be understood by the server!
";
diff --git a/driver/plugin/federated/http/Parser.cpp b/driver/plugin/federated/http/Parser.cpp
new file mode 100755
index 0000000..63195b4
--- /dev/null
+++ b/driver/plugin/federated/http/Parser.cpp
@@ -0,0 +1,174 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "../../../util/logger_wrapper.h"
+#include "Parser.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+void Parser::ParseRequestLine(std::string& str)
+{
+ std::istringstream istr(str);
+ std::vector command(
+ (std::istream_iterator(istr)),
+ std::istream_iterator()
+ );
+
+ if ((command.size() == 3) && (command[0].find(METHOD) != std::string::npos) &&
+ (command[2].find(HTTP_VERSION) != std::string::npos)) {
+ parser_state_ = STATE::PARSE_HEADER;
+ } else if ((command.size() == 3) &&
+ (command[0].find(GET_METHOD) != std::string::npos) &&
+ command[2].find(HTTP_VERSION) != std::string::npos) {
+ const size_t question_mark_pos = command[1].find('?');
+ if (question_mark_pos != std::string::npos) {
+ std::string parsed_body = command[1].substr(question_mark_pos + 1);
+ ParseBodyLine(parsed_body);
+ } else {
+ parser_state_ = STATE::PARSE_GET_REQUEST;
+ }
+ } else {
+ throw std::runtime_error("Request line contains wrong information.");
+ }
+}
+
+void Parser::ParseHeaderLine(std::string& str)
+{
+ if (str == "\r") {
+ parser_state_ = STATE::PARSE_BODY;
+
+ return;
+ }
+
+ const size_t ind = str.find(':');
+ header_size_ += str.size();
+
+ if ((ind == std::string::npos) || (header_size_ > MAX_HEADER_SIZE)) {
+ throw std::runtime_error("Received invalid header line.");
+ }
+
+ str.erase(std::remove_if(str.begin() + static_cast(ind), str.end(), ::isspace), str.end());
+
+ header_.insert({
+ std::string(str.begin(), str.begin() + static_cast(ind)),
+ std::string(str.begin() + static_cast(ind) + 1, str.end())
+ });
+}
+
+void Parser::ParseBodyLine(std::string& str)
+{
+ auto str_begin = str.begin();
+ auto str_end = str.end();
+
+ while (true) {
+ auto equal_it = find(str_begin, str_end, '=');
+ auto ampersand_it = find(str_begin, str_end, '&');
+
+ if (equal_it == str_end) {
+ throw std::runtime_error("Received invalid body line.");
+ }
+
+ body_.insert({
+ std::string(str_begin, equal_it),
+ std::string(equal_it + 1, ampersand_it)
+ });
+
+ if ((equal_it == str_end) || (ampersand_it == str_end)) {
+ break;
+ }
+
+ str_begin = ampersand_it + 1;
+ }
+
+ parser_state_ = STATE::PARSE_FINISHED;
+}
+
+void Parser::ParsePostRequest(std::string& str)
+{
+ switch (parser_state_) {
+ case STATE::PARSE_REQUEST:
+ ParseRequestLine(str);
+ break;
+ case STATE::PARSE_HEADER:
+ ParseHeaderLine(str);
+ break;
+ case STATE::PARSE_BODY:
+ if (!header_.contains("Content-Type") || !header_.contains("Content-Length") ||
+ (header_["Content-Type"] != "application/x-www-form-urlencoded") ||
+ (header_["Content-Length"] != std::to_string(str.size())))
+ {
+ throw std::runtime_error("Can't start parsing body as header contains invalid information.");
+ }
+
+ ParseBodyLine(str);
+ break;
+ default:
+ break;
+ }
+}
+
+STATUS Parser::Parse(std::istream &in)
+{
+ std::string str;
+ bool is_line_received = false;
+
+ while (getline(in, str) && parser_state_ != STATE::PARSE_GET_REQUEST) {
+ is_line_received = true;
+
+ try {
+ ParsePostRequest(str);
+ } catch (std::exception& e) {
+ DLOG(INFO) << e.what();
+ break;
+ }
+ }
+
+ if (parser_state_ == STATE::PARSE_GET_REQUEST) {
+ parser_state_ = STATE::PARSE_REQUEST;
+ return STATUS::GET_SUCCESS;
+ }
+
+ if (parser_state_ != STATE::PARSE_FINISHED) {
+ return is_line_received ? STATUS::FAILED : STATUS::EMPTY_REQUEST;
+ }
+
+ return STATUS::SUCCEED;
+}
+
+Parser::Parser()
+ : parser_state_(STATE::PARSE_REQUEST)
+ , header_size_(0)
+{
+ ; // Do nothing.
+}
+
+bool Parser::IsFinished() const
+{
+ return parser_state_ == STATE::PARSE_FINISHED;
+}
+
+std::string Parser::RetrieveAuthCode(std::string& state)
+{
+ if (body_.contains("code"))
+ {
+ return body_["code"];
+ }
+
+ return "";
+}
diff --git a/driver/plugin/federated/http/Parser.h b/driver/plugin/federated/http/Parser.h
new file mode 100755
index 0000000..6b13934
--- /dev/null
+++ b/driver/plugin/federated/http/Parser.h
@@ -0,0 +1,114 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+enum class STATE
+{
+ PARSE_REQUEST,
+ PARSE_HEADER,
+ PARSE_BODY,
+ PARSE_FINISHED,
+ PARSE_GET_REQUEST
+};
+
+enum class STATUS
+{
+ SUCCEED,
+ FAILED,
+ EMPTY_REQUEST,
+ GET_SUCCESS
+};
+
+/*
+* This class is used to parse the HTTP POST request
+* and retrieve authorization code.
+*/
+class Parser {
+ public:
+
+ Parser();
+
+ ~Parser() = default;
+
+ /*
+ * Initiate request parsing.
+ *
+ * Return true if parse was successful, otherwise false.
+ */
+ STATUS Parse(std::istream &in);
+
+ /*
+ * Check if parser is finished to parse the POST request.
+ *
+ * Return true if parsing was successfully finished, otherwise false.
+ */
+ bool IsFinished() const;
+
+ /*
+ * Retrieve authorization code.
+ *
+ * Return received authorization code or empty string.
+ */
+ std::string RetrieveAuthCode(std::string& state);
+
+ private:
+ /*
+ * Parse received POST request line by line.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParsePostRequest(std::string& str);
+
+ /*
+ * Parse request-line and perform verification for:
+ * method, Request-URI and HTTP-Version.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseRequestLine(std::string& str);
+
+ /*
+ * Parse request header line.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseHeaderLine(std::string& str);
+
+ /*
+ * Parse request body line in application/x-www-form-urlencoded format.
+ *
+ * Return void or throw an exception if parse wasn't successful.
+ */
+ void ParseBodyLine(std::string& str);
+
+ /*
+ * Expected METHOD, URI and HTTP VERSION in request line.
+ */
+ const std::string METHOD = "POST";
+ // const std::string URI = "/redshift/";
+ const std::string HTTP_VERSION = "HTTP/1.1";
+ const int MAX_HEADER_SIZE = 8192;
+
+ const std::string GET_METHOD = "GET";
+ const std::string PKCE_URI = "/?code=";
+
+ STATE parser_state_;
+ size_t header_size_;
+ std::unordered_map header_;
+ std::unordered_map body_;
+
+};
diff --git a/driver/plugin/federated/http/Selector.cpp b/driver/plugin/federated/http/Selector.cpp
new file mode 100644
index 0000000..5dd7f79
--- /dev/null
+++ b/driver/plugin/federated/http/Selector.cpp
@@ -0,0 +1,56 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "Selector.h"
+
+void Selector::Register(SOCKET sfd)
+{
+ if (sfd == -1) {
+ return;
+ }
+
+ FD_SET(sfd, &master_fds_);
+
+#ifdef WIN32
+ max_fd_ = max(max_fd_, sfd);
+#else
+ max_fd_ = std::max(max_fd_, sfd);
+#endif
+}
+
+void Selector::Unregister(SOCKET sfd)
+{
+ FD_CLR(sfd, &master_fds_);
+}
+
+bool Selector::Select(struct timeval *tv)
+{
+ // As select will modify the file descriptor set we should keep temporary set to reflect the ready fd.
+ fd_set read_fds = master_fds_;
+
+ if (select(max_fd_ + 1, &read_fds, nullptr, nullptr, tv) > 0) {
+ for (SOCKET sfd = 0; sfd <= max_fd_; sfd++) {
+ if (FD_ISSET(sfd, &read_fds)) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+Selector::Selector() : max_fd_(0)
+{
+ FD_ZERO(&master_fds_);
+}
diff --git a/driver/plugin/federated/http/Selector.h b/driver/plugin/federated/http/Selector.h
new file mode 100755
index 0000000..d48a7de
--- /dev/null
+++ b/driver/plugin/federated/http/Selector.h
@@ -0,0 +1,52 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "SocketSupport.h"
+
+#include
+
+/*
+* This class is used to support insertion/deletion operations on
+* master descriptor set and selecting incoming connections.
+*/
+class Selector
+{
+ public:
+ Selector();
+
+ ~Selector() = default;
+
+ /*
+ * Update master descriptor set with new socket.
+ */
+ void Register(SOCKET sfd);
+
+ /*
+ * Remove socket from master descriptor set.
+ */
+ void Unregister(SOCKET sfd);
+
+ /*
+ * If any incoming connections is readable return true, or
+ * false otherwise.
+ */
+ bool Select(struct timeval* tv);
+
+ private:
+
+ SOCKET max_fd_;
+ fd_set master_fds_;
+};
diff --git a/driver/plugin/federated/http/Socket.cpp b/driver/plugin/federated/http/Socket.cpp
new file mode 100755
index 0000000..328f7f9
--- /dev/null
+++ b/driver/plugin/federated/http/Socket.cpp
@@ -0,0 +1,214 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "AddrInformation.h"
+#include "Socket.h"
+
+#include
+#include
+#include
+
+int Socket::Receive(char *buffer, int length, int flags) const
+{
+ int nbytes = 0;
+ int filled = 0;
+ auto start = std::chrono::system_clock
+ ::now();
+
+ const int receive_wait = 200;
+ // Give a chance to fully receive packet in case if there is no data in
+ // non-blocking socket.
+ while ((std::chrono::system_clock::now() - start < std::chrono::seconds(1))) {
+ nbytes = static_cast(recv(socket_fd_, buffer + filled, length - filled - 1, 0));
+
+ if (nbytes <= 0) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(receive_wait));
+ } else {
+ filled += nbytes;
+ }
+ }
+
+ return filled == 0 ? nbytes : filled;
+}
+
+int Socket::Send(const char *buffer, int length, int flags) const
+ {
+ int nbytes = 0;
+ int sent = 0;
+
+ while ((nbytes = static_cast(send(socket_fd_, buffer + sent, length - sent, flags))) > 0) {
+ sent += nbytes;
+ }
+
+ return sent == 0 ? nbytes : sent;
+}
+
+void Socket::PrepareListenSocket(const std::string& port)
+{
+ const AddrInformation addr_info(port);
+
+ for (const auto& ptr : addr_info) {
+ socket_fd_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+
+ if (socket_fd_ == INVALID_SOCKET) {
+ continue;
+ }
+
+ if (!SetReusable()) {
+ Close();
+
+ throw std::runtime_error("Unable to reuse the address.");
+ }
+
+ if (Bind(ptr->ai_addr, ptr->ai_addrlen) == 0) {
+ break;
+ }
+
+ Close();
+ }
+
+ if (!SetNonBlocking()) {
+ Close();
+
+ throw std::runtime_error("Unable to set socket to non-blocking mode.");
+ }
+
+ if ((socket_fd_ == INVALID_SOCKET) || (Listen(CONNECTION_BACKLOG))) {
+ Close();
+
+ throw std::runtime_error("Can not start listening on port: " + port);
+ }
+}
+
+bool Socket::Close()
+{
+ if (socket_fd_ == -1) {
+ return false;
+ }
+
+#if (defined(_WIN32) || defined(_WIN64))
+ int res = closesocket(socket_fd_);
+#else
+ const int res = close(socket_fd_);
+#endif
+
+ socket_fd_ = INVALID_SOCKET;
+
+ return res == 0;
+}
+
+Socket::Socket() : socket_fd_(INVALID_SOCKET)
+{
+ ; // Do nothing.
+}
+
+Socket::Socket(SOCKET sfd) : socket_fd_(sfd)
+{
+ ; // Do nothing.
+}
+
+Socket::Socket(Socket&& s)
+{
+ socket_fd_ = std::move(s.socket_fd_);
+
+ s.socket_fd_ = INVALID_SOCKET;
+}
+
+Socket& Socket::operator=(Socket&& s)
+{
+ socket_fd_ = std::move(s.socket_fd_);
+
+ s.socket_fd_ = INVALID_SOCKET;
+
+ return *this;
+}
+
+Socket::~Socket()
+{
+ Close();
+}
+
+bool Socket::SetNonBlocking() const
+{
+#if (defined(_WIN32) || defined(_WIN64))
+ unsigned long mode = 1;
+ return ioctlsocket(socket_fd_, FIONBIO, &mode) == 0 ? true : false;
+#else
+ int flags = fcntl(socket_fd_, F_GETFL, 0);
+ return fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK) == 0 ? true : false;
+#endif
+}
+
+bool Socket::IsNonBlockingError() const
+{
+#if (defined(_WIN32) || defined(_WIN64))
+ return WSAGetLastError() == WSAEWOULDBLOCK;
+#else
+ return errno == EWOULDBLOCK;
+#endif
+}
+
+void Socket::Register(Selector& selector) const
+{
+ selector.Register(socket_fd_);
+}
+
+void Socket::Unregister(Selector& selector) const
+{
+ selector.Unregister(socket_fd_);
+}
+
+int Socket::Bind(const struct sockaddr *address, size_t address_len) const
+{
+ return bind(socket_fd_, address, (int)address_len);
+}
+
+int Socket::GetListenPort() const
+{
+ sockaddr_storage addr;
+ socklen_t len = sizeof(addr);
+ int port = 0;
+
+ getsockname(socket_fd_, (struct sockaddr*)&addr, &len);
+ port = htons(((sockaddr_in*)&addr)->sin_port);
+
+ return port;
+}
+
+int Socket::Listen(int backlog) const
+{
+ return listen(socket_fd_, backlog);
+}
+
+Socket Socket::Accept() const
+{
+ sockaddr_storage remoteaddr;
+ socklen_t addrlen = sizeof(remoteaddr);
+
+ return Socket(accept(socket_fd_, (struct sockaddr*)&remoteaddr, &addrlen));
+}
+
+bool Socket::SetReusable() const
+{
+ int yes = 1;
+
+ // Windows: int setsockopt(SOCKET s, int level, int optname, const char *optval, int optlen);
+ // The optval parameter has ponter to const char, but according to the MSDN we should use int:
+ // To enable a Boolean option, the optval parameter points to a nonzero integer.
+ // To disable the option optval points to an integer equal to zero.
+ // The optlen parameter should be equal to sizeof(int) for Boolean options.
+ // On Linux the optval parameter is pointer to void.
+ return setsockopt(socket_fd_, SOL_SOCKET, SO_REUSEADDR,
+ (char *)&yes, sizeof(yes)) == 0;
+}
diff --git a/driver/plugin/federated/http/Socket.h b/driver/plugin/federated/http/Socket.h
new file mode 100755
index 0000000..2d4e465
--- /dev/null
+++ b/driver/plugin/federated/http/Socket.h
@@ -0,0 +1,126 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOCKET_H_
+#define SOCKET_H_
+
+#pragma once
+
+#include "SocketSupport.h"
+#include "Selector.h"
+#include
+
+/*
+* This class is used to wrap the network socket
+* in order to have cross-platfrom code to send and receive
+* data from incoming connections.
+*/
+class Socket
+{
+ public:
+ Socket();
+
+ Socket( SOCKET sfd);
+
+ Socket(const Socket& s) = delete;
+
+ Socket& operator=(const Socket& s) = delete;
+
+ Socket(Socket&& s);
+
+ Socket& operator=(Socket&& s);
+
+ ~Socket();
+
+ /*
+ * Close a socket file descriptor, return true on success
+ * or false in case of error.
+ */
+ bool Close();
+
+ /*
+ * Get port where socket is listening.
+ */
+ int GetListenPort() const;
+
+ /*
+ * Listen for connections on a socket and return zero on success,
+ * or -1 in case of error.
+ */
+ int Listen(int backlog) const;
+
+ /*
+ * Accept a connection on a socket and return StreamSocket object.
+ */
+ Socket Accept() const;
+
+ /*
+ * Forcibly bind to a port in use by another socket and return true on success,
+ * or false in case of error.
+ */
+ bool SetReusable() const;
+
+ /*
+ * Set socket to non-blocking mode and return true on success
+ * or false in case of error.
+ */
+ bool SetNonBlocking() const;
+
+ /*
+ * Return true if error is caused by non-blocking mode of socket
+ * otherwise return false.
+ */
+ bool IsNonBlockingError() const;
+
+ /*
+ * Prepare socket to handle incoming connections.
+ */
+ void PrepareListenSocket(const std::string& port);
+
+ /*
+ * Register socket in master file descriptor set using Selector class.
+ */
+ void Register(Selector& selector) const;
+
+ /*
+ * Clear socket in master file descriptor set using Selector class.
+ */
+ void Unregister(Selector& selector) const;
+
+ /*
+ * Receive a message from a socket and return the number of bytes received,
+ * or -1 if an error occurred.
+ */
+ int Receive(char *buffer, int length, int flags) const;
+
+ /*
+ * Send a message on a socket and return the number of bytes sent,
+ * or -1 if an error occured.
+ */
+ int Send(const char *buffer, int length, int flags) const;
+
+ /*
+ * Bind a name to a socket and return zero on success,
+ * or -1 in case of error.
+ */
+ int Bind(const struct sockaddr *address, size_t address_len) const;
+
+ private:
+
+ const int CONNECTION_BACKLOG = 10;
+
+ SOCKET socket_fd_;
+};
+
+#endif // SOCKET_H_
diff --git a/driver/plugin/federated/http/SocketStream.cpp b/driver/plugin/federated/http/SocketStream.cpp
new file mode 100644
index 0000000..1d41a3c
--- /dev/null
+++ b/driver/plugin/federated/http/SocketStream.cpp
@@ -0,0 +1,50 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "SocketStream.h"
+
+SocketStream::SocketStream(Socket& s) : received_size_(0), socket_(s)
+{
+ setg(input_buffer_, input_buffer_, input_buffer_);
+}
+
+SocketStream::~SocketStream()
+{
+ sync();
+}
+
+int SocketStream::underflow()
+{
+ if (gptr() < egptr()) {
+ return traits_type::to_int_type(*gptr());
+ }
+
+ int received_bytes = socket_.Receive(input_buffer_, SIZE - 1, 0);
+
+ /*
+ * Return EOF in the following cases:
+ * If the received bytes less than zero (error situation);
+ * If the received bytes equal to zero (socket peer has performed an orderly shutdown);
+ * If the length of received packets more than MAX_SIZE.
+ */
+ if ((received_bytes <= 0) || (received_size_ + received_bytes > MAX_SIZE)) {
+ return traits_type::eof();
+ }
+
+ received_size_ += received_bytes;
+
+ setg(input_buffer_, input_buffer_, input_buffer_ + received_bytes);
+
+ return traits_type::to_int_type(*gptr());
+}
\ No newline at end of file
diff --git a/driver/plugin/federated/http/SocketStream.h b/driver/plugin/federated/http/SocketStream.h
new file mode 100755
index 0000000..5d1eb7b
--- /dev/null
+++ b/driver/plugin/federated/http/SocketStream.h
@@ -0,0 +1,59 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOCKET_STREAM_H_
+#define SOCKET_STREAM_H_
+
+#pragma once
+
+#include "Socket.h"
+
+#include
+
+/*
+* This class is used to wrap socket by std::stream.
+*/
+class SocketStream : public std::streambuf
+{
+ public:
+ /*
+ * Construct object and set the value for the pointers that define
+ * the boundaries of the buffered portion of the controlled INPUT sequence.
+ */
+ SocketStream(Socket& s);
+
+ /*
+ * Call sync inside the destructor to synchronize the contents in the stream buffer
+ * with those of the associated character sequence.
+ */
+ virtual ~SocketStream();
+
+ protected:
+ /*
+ * Virtual function called by other member functions to get the current character
+ * in the controlled input sequence without changing the current position.
+ * It is called by public member functions such as sgetc to request
+ * a new character when there are no read positions available at the get pointer (gptr).
+ */
+ virtual int underflow();
+
+ static const int SIZE = 2048;
+ static const int MAX_SIZE = 16384;
+
+ int received_size_;
+ char input_buffer_[SIZE];
+ Socket& socket_;
+};
+
+#endif // SOCKET_STREAM_H_
diff --git a/driver/plugin/federated/http/SocketSupport.h b/driver/plugin/federated/http/SocketSupport.h
new file mode 100644
index 0000000..074c0d2
--- /dev/null
+++ b/driver/plugin/federated/http/SocketSupport.h
@@ -0,0 +1,35 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if (defined(_WIN32) || defined(_WIN64))
+
+#include
+#include
+#include "Ws2tcpip.h"
+#pragma comment(lib, "Ws2_32.lib")
+
+#else /* Linux or MAC */
+
+#include
+#include
+#include
+#include
+#include
+
+typedef int SOCKET;
+#define INVALID_SOCKET -1
+
+#endif
diff --git a/driver/plugin/federated/http/WEBServer.cpp b/driver/plugin/federated/http/WEBServer.cpp
new file mode 100755
index 0000000..1b84cf5
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer.cpp
@@ -0,0 +1,153 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "../../../util/logger_wrapper.h"
+#include "HtmlResponse.h"
+#include "SocketStream.h"
+#include "WEBServer.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+WEBServer::WEBServer( std::string& state,
+ std::string& port, std::string& timeout) :
+ state_(state),
+ port_(port),
+ timeout_(std::stoi(timeout)),
+ selector_(),
+ parser_(),
+ listen_socket_(),
+ listen_port_(0),
+ connections_counter_(0),
+ listening_(false)
+{
+ ; // Do nothing.
+}
+
+void WEBServer::LaunchServer()
+{
+ // Create thread to launch the server to listen the incoming connection.
+ // Waiting for the redirect response from the /oauth2/authorize.
+ thread_ = std::thread(&WEBServer::ListenerThread, this);
+}
+
+void WEBServer::Join()
+{
+ if (thread_.joinable()) {
+ thread_.join();
+ }
+}
+
+std::string WEBServer::GetCode() const
+{
+ return code_;
+}
+
+int WEBServer::GetListenPort() const
+{
+ return listen_port_;
+}
+
+bool WEBServer::IsListening() const
+{
+ return listening_.load();
+}
+
+bool WEBServer::IsTimeout() const
+{
+ return connections_counter_ > 0 ? false : true;
+}
+
+void WEBServer::Cancel() {
+ cancel_ = true;
+}
+
+bool WEBServer::WEBServerInit()
+{
+ // Prepare the environment for get the socket description.
+ try {
+ listen_socket_.PrepareListenSocket(port_);
+ listen_socket_.Register(selector_);
+ listen_port_ = listen_socket_.GetListenPort();
+ } catch (std::exception& e) {
+ DLOG(INFO) << "Exception: " << e.what();
+
+ return false;
+ }
+
+ return true;
+}
+
+void WEBServer::HandleConnection()
+{
+ /* Trying to accept the pending incoming connection. */
+ Socket ssck(listen_socket_.Accept());
+
+ ++connections_counter_;
+
+ if (ssck.SetNonBlocking()) {
+ SocketStream socket_buffer(ssck);
+ std::istream socket_stream(&socket_buffer);
+
+ const STATUS status = parser_.Parse(socket_stream);
+
+ if (status == STATUS::SUCCEED) {
+ ssck.Send(ValidResponse.c_str(), static_cast(ValidResponse.size()), 0);
+ } else if (status == STATUS::FAILED) {
+ ssck.Send(InvalidResponse.c_str(), static_cast(InvalidResponse.size()), 0);
+ } else if (status == STATUS::GET_SUCCESS) {
+ const std::string response = std::vformat(GetResponse, std::make_format_args(listen_port_));
+ ssck.Send(response.c_str(), static_cast(response.size()), 0);
+ } else {
+ /* Nothing is received from socket. Continue to listen. */
+ return;
+ }
+ }
+}
+
+void WEBServer::Listen()
+{
+ // Set timeout for non-blocking socket to 1 sec to pass it to Select.
+ struct timeval tv = { .tv_sec=1, .tv_usec=0 };
+
+ if (selector_.Select(&tv)) {
+ HandleConnection();
+ }
+}
+
+void WEBServer::ListenerThread()
+{
+ if (!WEBServerInit()) {
+ DLOG(INFO) << "WEBServerInit Failed";
+ return;
+ }
+
+ auto start = std::chrono::system_clock::now();
+
+ listening_.store(true);
+
+ while ((std::chrono::system_clock::now() - start < std::chrono::seconds(timeout_)) && !parser_.IsFinished()) {
+ Listen();
+ }
+
+ code_ = parser_.RetrieveAuthCode(state_);
+
+ listen_socket_.Close();
+}
diff --git a/driver/plugin/federated/http/WEBServer.h b/driver/plugin/federated/http/WEBServer.h
new file mode 100755
index 0000000..4be1205
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer.h
@@ -0,0 +1,118 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef WEBSERVER_H_
+#define WEBSERVER_H_
+
+#pragma once
+
+#include "Parser.h"
+#include "Selector.h"
+#include "Socket.h"
+
+#include
+#include
+
+/*
+* This class is used to launch the HTTP WEB server in separate thread
+* to wait for the redirect response from the /oauth2/authorize and
+* extract the authorization code from it.
+*/
+class WEBServer
+{
+ public:
+ WEBServer(
+ std::string& state,
+ std::string& port,
+ std::string& timeout);
+
+ ~WEBServer() = default;
+
+ /*
+ * Launch the HTTP WEB server in separate thread.
+ */
+ void LaunchServer();
+
+ /*
+ * Wait until HTTP WEB server is finished.
+ */
+ void Join();
+
+ /*
+ * Extract the authorization code from response.
+ */
+ std::string GetCode() const;
+
+ /*
+ * Get port where server is listening.
+ */
+ int GetListenPort() const;
+
+ /*
+ * Extract the SAML Assertion from response.
+ */
+ std::string GetSamlResponse() const;
+
+ /*
+ * If server is listening for connections return true, otherwise return false.
+ */
+ bool IsListening() const;
+
+ /*
+ * If timeout happened return true, otherwise return false.
+ */
+ bool IsTimeout() const;
+
+ /*
+ * Cancel listen loop prematurely without waiting for the timeout.
+ */
+ void Cancel();
+
+ private:
+ /*
+ * Main HTTP WEB server function that perform initialization and
+ * listen for incoming connections for specified time by user.
+ */
+ void ListenerThread();
+
+ /*
+ * If incoming connection is available call HandleConnection.
+ */
+ void Listen();
+
+ /*
+ * Launch parser if incoming connection is acceptable.
+ */
+ void HandleConnection();
+
+ /*
+ * Perform socket preparation to launch the HTTP WEB server.
+ */
+ bool WEBServerInit();
+
+ std::string state_;
+ std::string port_;
+ int timeout_;
+ std::string code_;
+ std::thread thread_;
+ Selector selector_;
+ Parser parser_;
+ Socket listen_socket_;
+ int listen_port_;
+ int connections_counter_;
+ std::atomic listening_;
+ bool cancel_ = false;
+};
+
+#endif // WEBSERVER_H_
diff --git a/driver/plugin/federated/http/WEBServer_utils.cpp b/driver/plugin/federated/http/WEBServer_utils.cpp
new file mode 100644
index 0000000..49c4f2a
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer_utils.cpp
@@ -0,0 +1,40 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "WEBServer_utils.h"
+
+int WebServerUtils::GenerateRandomInteger(int low, int high)
+{
+ std::random_device rd;
+ std::mt19937 generator(rd());
+ std::uniform_int_distribution<> dist(low, high);
+
+ return dist(generator);
+}
+
+std::string WebServerUtils::GenerateState()
+{
+ const char STATE_CHAR_LIST[37] = "0123456789abcdefghijklmnopqrstuvwxyz";
+ const int chars_size = (sizeof(STATE_CHAR_LIST) / sizeof(*STATE_CHAR_LIST)) - 1;
+ const int rand_size = GenerateRandomInteger(9, chars_size - 1);
+ std::string state;
+
+ state.reserve(rand_size);
+
+ for (int i = 0; i < rand_size; ++i) {
+ state.push_back(STATE_CHAR_LIST[GenerateRandomInteger(0, rand_size)]);
+ }
+
+ return state;
+}
diff --git a/driver/plugin/federated/http/WEBServer_utils.h b/driver/plugin/federated/http/WEBServer_utils.h
new file mode 100644
index 0000000..d159738
--- /dev/null
+++ b/driver/plugin/federated/http/WEBServer_utils.h
@@ -0,0 +1,25 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef WEBSERVER_UTILS_H_
+#define WEBSERVER_UTILS_H_
+
+#include
+
+namespace WebServerUtils {
+ int GenerateRandomInteger(int low, int high);
+ std::string GenerateState();
+}
+
+#endif // WEBSERVER_UTILS_H_
diff --git a/driver/plugin/federated/okta_auth_plugin.cpp b/driver/plugin/federated/okta_auth_plugin.cpp
index 95c66ce..23028ee 100644
--- a/driver/plugin/federated/okta_auth_plugin.cpp
+++ b/driver/plugin/federated/okta_auth_plugin.cpp
@@ -12,17 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "okta_auth_plugin.h"
+#if (defined(_WIN32) || defined(_WIN64))
+ #include
+ #include
+ #undef GetObject
+#endif
-#include
+#include "okta_auth_plugin.h"
#include "html_util.h"
+#include "http/WEBServer.h"
+#include "http/WEBServer_utils.h"
+#include "saml_util.h"
-#include "../../util/aws_sdk_helper.h"
#include "../../util/connection_string_keys.h"
#include "../../util/logger_wrapper.h"
+#include "../../util/rds_strings.h"
#include "../../util/rds_utils.h"
-#include "saml_util.h"
OktaAuthPlugin::OktaAuthPlugin(DBC *dbc) : OktaAuthPlugin(dbc, nullptr) {}
@@ -113,7 +119,7 @@ SQLRETURN OktaAuthPlugin::Connect(
ret = next_plugin->Connect(ConnectionHandle, WindowHandle, OutConnectionString, BufferLength, StringLengthPtr, DriverCompletion);
// Unsuccessful connection using cached token
- // Skip cache and generate a new token to retry
+ // Skip cache and generate a new token to retry
if (!SQL_SUCCEEDED(ret) && token.second) {
LOG(WARNING) << "Cached token failed to connect. Retrying with fresh token";
// Update AWS Credentials
@@ -145,6 +151,16 @@ OktaSamlUtil::OktaSamlUtil(
}
sign_in_url = "https://" + idp_endpoint + ":" + idp_port + "/app/amazon_aws/" + app_id + "/sso/saml" + "?onetimetoken=";
session_token_url = "https://" + idp_endpoint + ":" + idp_port + "/api/v1/authn";
+
+ const std::string mfa_type_str = connection_attributes.contains(KEY_MFA_TYPE) ?
+ connection_attributes.at(KEY_MFA_TYPE) : "";
+ if (mfa_type_table.contains(mfa_type_str)) {
+ mfa_type = mfa_type_table.at(mfa_type_str);
+ }
+ mfa_port = connection_attributes.contains(KEY_MFA_PORT) ?
+ connection_attributes.at(KEY_MFA_PORT) : DEFAULT_PORT;
+ mfa_timeout = connection_attributes.contains(KEY_MFA_TIMEOUT) ?
+ connection_attributes.at(KEY_MFA_TIMEOUT) : DEFAULT_MFA_TIMEOUT;
}
std::string OktaSamlUtil::GetSamlAssertion()
@@ -214,10 +230,169 @@ std::string OktaSamlUtil::GetSessionToken()
LOG(ERROR) << "Unable to parse JSON from response";
return "";
}
+
+ if (mfa_type != NONE) {
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+
+ if (!json_view.KeyExists("stateToken")) {
+ LOG(ERROR) << "Could not find state token in JSON";
+ return "";
+ }
+
+ const std::string state_token = json_view.GetString("stateToken");
+ const Aws::Utils::Json::JsonView embedded_view = json_view.GetObject("_embedded");
+ Aws::Utils::Array factor_views = embedded_view.GetArray("factors");
+
+ std::string factor_id;
+ for (int i = 0; i < factor_views.GetLength(); i++) {
+ const std::string type = factor_views[i].GetString("factorType");
+ if (mfa_type == TOTP && type == "token:software:totp" || mfa_type == PUSH && type == "push") {
+ factor_id = factor_views[i].GetString("id");
+ }
+ }
+
+ if (factor_id.empty()) {
+ LOG(ERROR) << "Could not find factor in JSON";
+ return "";
+ }
+
+ const std::string verify_url = session_token_url + "/factors/" + factor_id + "/verify";
+ if (mfa_type == TOTP) {
+ return VerifyTOTPChallenge(verify_url, state_token);
+ }
+ if (mfa_type == PUSH) {
+ return VerifyPushChallenge(verify_url, state_token);
+ }
+ }
+
const Aws::Utils::Json::JsonView json_view = json_val.View();
if (!json_view.KeyExists("sessionToken")) {
LOG(ERROR) << "Could not find session token in JSON";
return "";
}
+
return json_view.GetString("sessionToken");
}
+
+std::string OktaSamlUtil::VerifyTOTPChallenge(
+ const std::string &verify_url,
+ const std::string &state_token)
+{
+ std::string state = WebServerUtils::GenerateState();
+ WEBServer srv(state, mfa_port, mfa_timeout);
+
+ srv.LaunchServer();
+
+ const std::string mfa_form_url = WEBSERVER_HOST + ":" + mfa_port;
+ try {
+#if (defined(_WIN32) || defined(_WIN64))
+ const HINSTANCE result = ShellExecute(NULL, RDS_TSTR(std::string("open")).c_str(), RDS_TSTR(mfa_form_url).c_str(), NULL, NULL, SW_SHOWNORMAL);
+ if (reinterpret_cast(result) <= 32) {
+ srv.Cancel();
+ }
+#else
+#if (defined(LINUX) || defined(__linux__))
+ const int result = system(("xdg-open " + mfa_form_url).c_str());
+#else
+ const int result = system(("open " + mfa_form_url).c_str());
+#endif
+ if (result != 0) {
+ srv.Cancel();
+ }
+#endif
+ } catch (const std::exception & e) {
+ srv.Cancel();
+ srv.Join();
+ LOG(ERROR) << "Could not open browser to obtain MFA token: " << e.what();
+ return "";
+ }
+
+ srv.Join();
+
+ const std::string pass_code = srv.GetCode();
+ if (pass_code.empty()) {
+ LOG(ERROR) << "MFA Authorization code was not obtained";
+ return "";
+ }
+
+ const std::shared_ptr req = Aws::Http::CreateHttpRequest(
+ verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
+ Aws::Utils::Json::JsonValue json_body;
+ json_body
+ .WithString("stateToken", state_token)
+ .WithString("passCode", pass_code);
+ const Aws::String json_str = json_body.View().WriteReadable();
+ const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size());
+ req->SetContentType("application/json");
+ req->AddContentBody(Aws::MakeShared("", json_str));
+ req->SetContentLength(json_len);
+ const std::shared_ptr response = http_client->MakeRequest(req);
+
+ // Check resp status
+ if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) {
+ LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode();
+ if (response->HasClientError()) {
+ LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage();
+ }
+ return "";
+ }
+
+ const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody());
+ if (!json_val.WasParseSuccessful()) {
+ LOG(ERROR) << "Unable to parse JSON from response";
+ return "";
+ }
+
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+ if (!json_view.KeyExists("sessionToken")) {
+ LOG(ERROR) << "Could not find session token in JSON";
+ return "";
+ }
+
+ return json_view.GetString("sessionToken");
+}
+
+std::string OktaSamlUtil::VerifyPushChallenge(
+ const std::string &verify_url,
+ const std::string &state_token)
+{
+ const std::chrono::time_point end_time = std::chrono::system_clock::now() + std::chrono::seconds(std::strtol(mfa_timeout.c_str(), nullptr, 0));
+ while (std::chrono::system_clock::now() < end_time) {
+ const std::shared_ptr req = Aws::Http::CreateHttpRequest(
+ verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
+ Aws::Utils::Json::JsonValue json_body;
+ json_body.WithString("stateToken", state_token);
+ const Aws::String json_str = json_body.View().WriteReadable();
+ const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size());
+ req->SetContentType("application/json");
+ req->AddContentBody(Aws::MakeShared("", json_str));
+ req->SetContentLength(json_len);
+ const std::shared_ptr response = http_client->MakeRequest(req);
+
+ // Check resp status
+ if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) {
+ LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode();
+ if (response->HasClientError()) {
+ LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage();
+ }
+ } else {
+ const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody());
+ if (!json_val.WasParseSuccessful()) {
+ LOG(ERROR) << "Unable to parse JSON from response";
+ continue;
+ }
+
+ const Aws::Utils::Json::JsonView json_view = json_val.View();
+ if (!json_view.KeyExists("sessionToken")) {
+ LOG(ERROR) << "Could not find session token in JSON";
+ continue;
+ }
+
+ return json_view.GetString("sessionToken");
+ }
+
+ std::this_thread::sleep_for(std::chrono::seconds(VERIFY_PUSH_INTERVAL));
+ }
+ LOG(ERROR) << "The MFA challenge was not completed in time";
+ return "";
+}
diff --git a/driver/plugin/federated/okta_auth_plugin.h b/driver/plugin/federated/okta_auth_plugin.h
index d0239a4..8f4d44c 100644
--- a/driver/plugin/federated/okta_auth_plugin.h
+++ b/driver/plugin/federated/okta_auth_plugin.h
@@ -21,6 +21,18 @@
#include "../base_plugin.h"
#include "../../driver.h"
+typedef enum {
+ NONE,
+ TOTP,
+ PUSH
+} MfaType;
+
+static std::map const mfa_type_table = {
+ {"", MfaType::NONE},
+ {VALUE_MFA_TOTP, MfaType::TOTP},
+ {VALUE_MFA_PUSH, MfaType::PUSH}
+};
+
class OktaSamlUtil : public SamlUtil {
public:
OktaSamlUtil(const std::map &connection_attributes);
@@ -30,8 +42,19 @@ class OktaSamlUtil : public SamlUtil {
private:
std::string GetSessionToken();
+ std::string VerifyTOTPChallenge(const std::string &verify_url, const std::string &state_token);
+
+ std::string VerifyPushChallenge(const std::string &verify_url, const std::string &state_token);
+
+ static inline const std::string DEFAULT_MFA_TIMEOUT = "60";
+ static inline const int VERIFY_PUSH_INTERVAL = 5;
+ static inline const std::string DEFAULT_PORT = "8080";
+ static inline const std::string WEBSERVER_HOST = "http://127.0.0.1";
std::string sign_in_url;
std::string session_token_url;
+ MfaType mfa_type;
+ std::string mfa_port;
+ std::string mfa_timeout;
static inline const std::regex SAML_RESPONSE_PATTERN = std::regex("");
};
diff --git a/driver/util/attribute_validator.cpp b/driver/util/attribute_validator.cpp
index 4b68941..275fa55 100644
--- a/driver/util/attribute_validator.cpp
+++ b/driver/util/attribute_validator.cpp
@@ -34,7 +34,9 @@ bool AttributeValidator::ShouldKeyBeUnsignedInt(const std::string& key) {
KEY_FAILOVER_TIMEOUT,
KEY_LIMITLESS_MONITOR_INTERVAL_MS,
KEY_ROUTER_MAX_RETRIES,
- KEY_LIMITLESS_MAX_RETRIES
+ KEY_LIMITLESS_MAX_RETRIES,
+ KEY_MFA_PORT,
+ KEY_MFA_TIMEOUT
};
return INTEGER_KEYS.contains(key);
}
diff --git a/driver/util/connection_string_helper.h b/driver/util/connection_string_helper.h
index f97676a..a607a0d 100644
--- a/driver/util/connection_string_helper.h
+++ b/driver/util/connection_string_helper.h
@@ -48,6 +48,9 @@ static std::unordered_set const aws_odbc_key_set = {
KEY_HTTP_CONNECT_TIMEOUT,
KEY_RELAY_PARTY_ID,
KEY_APP_ID,
+ KEY_MFA_TYPE,
+ KEY_MFA_PORT,
+ KEY_MFA_TIMEOUT,
KEY_DATABASE_DIALECT,
KEY_HOST_SELECTOR_STRATEGY,
KEY_ENABLE_FAILOVER,
diff --git a/driver/util/connection_string_keys.h b/driver/util/connection_string_keys.h
index 435d570..f716eb5 100644
--- a/driver/util/connection_string_keys.h
+++ b/driver/util/connection_string_keys.h
@@ -75,6 +75,11 @@
#define KEY_RELAY_PARTY_ID "RELAY_PARTY_ID"
/* OKTA */
#define KEY_APP_ID "APP_ID"
+#define KEY_MFA_TYPE "MFA_TYPE"
+#define KEY_MFA_PORT "MFA_PORT"
+#define KEY_MFA_TIMEOUT "MFA_TIMEOUT"
+#define VALUE_MFA_TOTP "TOTP"
+#define VALUE_MFA_PUSH "PUSH"
/* Host Selectors */
#define KEY_HOST_SELECTOR_STRATEGY "HOST_SELECTOR_STRATEGY"