diff --git a/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md b/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md index 1cda4ed..7e57c6f 100644 --- a/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md +++ b/docs/using-the-aws-odbc-wrapper/plugins/okta-authentication-plugin.md @@ -14,29 +14,33 @@ When a user wants access to a resource, it authenticates with the IdP. From this 1. Follow steps in [Enable AWS IAM Database Authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.Enabling.html) to setup IAM authentication. 2. Configure Okta as the AWS identity provider following [Okta's official documentation](https://help.okta.com/en-us/content/topics/deploymentguides/aws/aws-deployment.htm) +3. (Optional) Enable MFA. MFA through Okta Verify is supported for the Push and OTP methods. Please ensure the authentication policies and/or global session policies have been configured to use MFA. ### Connection String / DSN Configuration for Okta Authentication Plugin Support -| Field | Connection Option Key | Value | Default Value | Sample Value | -|-----------------------|------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|--------------------------------------------------------| -| Authentication Type | `RDS_AUTH_TYPE` | Must be `OKTA`. | `database` | `OKTA` | -| Server | `SERVER` | Database instance server host. | nil | `database.us-east-1-rds.amazon.com` | -| Port | `PORT` | Port that the database is listening on. | nil | 5432 | -| User Name | `UID` | Database user name for IAM authentication. | nil | `iam_user` | -| IAM Host | `IAM_HOST` | The endpoint used to generate the authentication token. This is only required if you are connecting using custom endpoints such as an IP address. | nil | `database.us-east-1-rds.amazon.com` | -| Region | `REGION` | The region of the database for IAM authentication. | `us-east-1` | `us-east-1` | -| Database | `DATABASE` | Default database that a user will work on. | nil | `my_database` | -| Token Expiration | `TOKEN_EXPIRATION` | Token expiration in seconds, supported max value is 900. | 900 | 900 | -| IdP Endpoint | `IDP_ENDPOINT` | The ADFS host that is used to authenticate with. | nil | `my-adfs-host.com` | -| IdP Port | `IDP_PORT` | The ADFS host port. | 443 | 443 | -| IdP User Name | `IDP_USERNAME` | The user name for the IdP Endpoint server. | nil | `user@email.com` | -| IdP Password | `IDP_PASSWORD` | The IdP user's password. | nil | `my_password_123` | -| Role ARN | `IDP_ROLE_ARN` | The ARN of the IAM Role that is to be assumed for database access. | nil | `arn:aws:iam::123412341234:role/ADFS-SAML-Assume` | -| IdP SAML Provider ARN | `IDP_SAML_ARN` | The ARN of the Identity Provider. | nil | `arn:aws:iam::123412341234:saml-provider/ADFS-AWS-IAM` | -| HTTP Socket Timeout | `HTTP_SOCKET_TIMEOUT` | The socket timeout value in milliseconds for the HttpClient reading. | 3000 | 3000 | -| HTTP Connect Timeout | `HTTP_CONNECT_TIMEOUT` | The connect timeout value in milliseconds for the HttpClient. | 5000 | 5000 | -| App ID | `APP_ID` | The application ID for AWS configured on. | nil | `my-app-id` | -| Extra URL Encode | `EXTRA_URL_ENCODE` | Generated tokens can have URL encoding prefix duplication for scenarios where underlying drivers automatically decode the URL before passing to the database for connections. | `0` | `1` | +| Field | Connection Option Key | Value | Default Value | Sample Value | +|-----------------------|------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|--------------------------------------------------------| +| Authentication Type | `RDS_AUTH_TYPE` | Must be `OKTA`. | `database` | `OKTA` | +| Server | `SERVER` | Database instance server host. | nil | `database.us-east-1-rds.amazon.com` | +| Port | `PORT` | Port that the database is listening on. | nil | `5432` | +| User Name | `UID` | Database user name for IAM authentication. | nil | `iam_user` | +| IAM Host | `IAM_HOST` | The endpoint used to generate the authentication token. This is only required if you are connecting using custom endpoints such as an IP address. | nil | `database.us-east-1-rds.amazon.com` | +| Region | `REGION` | The region of the database for IAM authentication. | `us-east-1` | `us-east-1` | +| Database | `DATABASE` | Default database that a user will work on. | nil | `my_database` | +| Token Expiration | `TOKEN_EXPIRATION` | Token expiration in seconds, supported max value is 900. | `900` | `900` | +| IdP Endpoint | `IDP_ENDPOINT` | The ADFS host that is used to authenticate with. | nil | `my-adfs-host.com` | +| IdP Port | `IDP_PORT` | The ADFS host port. | `443` | `443` | +| IdP User Name | `IDP_USERNAME` | The user name for the IdP Endpoint server. | nil | `user@email.com` | +| IdP Password | `IDP_PASSWORD` | The IdP user's password. | nil | `my_password_123` | +| Role ARN | `IDP_ROLE_ARN` | The ARN of the IAM Role that is to be assumed for database access. | nil | `arn:aws:iam::123412341234:role/ADFS-SAML-Assume` | +| IdP SAML Provider ARN | `IDP_SAML_ARN` | The ARN of the Identity Provider. | nil | `arn:aws:iam::123412341234:saml-provider/ADFS-AWS-IAM` | +| HTTP Socket Timeout | `HTTP_SOCKET_TIMEOUT` | The socket timeout value in milliseconds for the HttpClient reading. | `3000` | `3000` | +| HTTP Connect Timeout | `HTTP_CONNECT_TIMEOUT` | The connect timeout value in milliseconds for the HttpClient. | `5000` | `5000` | +| App ID | `APP_ID` | The application ID for AWS configured on. | nil | `my-app-id` | +| Extra URL Encode | `EXTRA_URL_ENCODE` | Generated tokens can have URL encoding prefix duplication for scenarios where underlying drivers automatically decode the URL before passing to the database for connections. | `0` | `1` | +| MFA Type | `MFA_TYPE` | The MFA type the user specifies. The available options are: `TOTP`, `PUSH`. **Note**: the `TOTP` type requires a web browser to be used. | nil | `TOTP` | +| MFA Port | `MFA_PORT` | The port used to connect to `127.0.0.1` to provide the one time code when using TOTP as the MFA Type. | `8080` | `8000` | +| MFA Timeout | `MFA_TIMEOUT` | The time in seconds to complete the MFA challenge before the connection fails. | `60` | `30` | > [!WARNING]\ > Using IAM Authentication, connections to the database must have SSL enabled. Please refer to the underlying driver's specifications to enable this. diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 07e6bbd..b478067 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -74,6 +74,17 @@ set(INC ${CMAKE_CURRENT_SOURCE_DIR}/host_info.h ${CMAKE_CURRENT_SOURCE_DIR}/odbcapi.h ${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_rds_helper.h + + # Webserver + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/AddrInformation.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/HtmlResponse.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Selector.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Socket.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketStream.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketSupport.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer.h + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer_utils.h ) set(SRC @@ -116,6 +127,15 @@ set(SRC ${CMAKE_CURRENT_SOURCE_DIR}/host_info.cpp ${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/odbcapi_rds_helper.cpp + + # Webserver + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/AddrInformation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Parser.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Selector.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/Socket.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/SocketStream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/plugin/federated/http/WEBServer_utils.cpp ) # GUI diff --git a/driver/gui/odbcsetup.rc b/driver/gui/odbcsetup.rc index 637c65d..6581987 100644 --- a/driver/gui/odbcsetup.rc +++ b/driver/gui/odbcsetup.rc @@ -148,43 +148,49 @@ STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN COMBOBOX IDC_AUTH_MODE,62,7,96,30,CBS_DROPDOWN | CBS_SORT | WS_VSCROLL | WS_TABSTOP - EDITTEXT IDC_REGION,62,24,96,14,ES_AUTOHSCROLL - EDITTEXT IDC_EXPIRE,62,44,96,14,ES_AUTOHSCROLL - EDITTEXT IDC_IAM_HOST,62,82,96,14,ES_AUTOHSCROLL | WS_DISABLED - EDITTEXT IDC_IAM_PORT,62,101,96,14,ES_AUTOHSCROLL | WS_DISABLED - EDITTEXT IDC_SECRET,62,120,96,14,ES_AUTOHSCROLL | WS_DISABLED - EDITTEXT IDC_SECRET_REGION,62,139,96,14,ES_AUTOHSCROLL | WS_DISABLED - EDITTEXT IDC_SECRET_END,62,159,96,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_REGION,62,22,96,14,ES_AUTOHSCROLL + EDITTEXT IDC_EXPIRE,62,39,96,14,ES_AUTOHSCROLL + EDITTEXT IDC_IAM_HOST,62,56,96,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_IAM_PORT,62,73,96,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_SECRET,62,90,96,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_SECRET_REGION,62,107,96,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_SECRET_END,62,124,96,14,ES_AUTOHSCROLL | WS_DISABLED EDITTEXT IDC_IDP_UID,209,7,106,14,ES_AUTOHSCROLL - EDITTEXT IDC_IDP_PWD,209,25,106,14,ES_PASSWORD | ES_AUTOHSCROLL - EDITTEXT IDC_IDP_END,209,45,106,14,ES_AUTOHSCROLL - EDITTEXT IDC_APP_ID,209,65,106,14,ES_AUTOHSCROLL | WS_DISABLED - EDITTEXT IDC_ROLE_ARN,209,85,106,14,ES_AUTOHSCROLL - EDITTEXT IDC_IDP_ARN,209,104,106,14,ES_AUTOHSCROLL - EDITTEXT IDC_IDP_PORT,209,124,106,14,ES_AUTOHSCROLL + EDITTEXT IDC_IDP_PWD,209,24,106,14,ES_PASSWORD | ES_AUTOHSCROLL + EDITTEXT IDC_IDP_END,209,42,106,14,ES_AUTOHSCROLL + EDITTEXT IDC_APP_ID,209,146,106,14,ES_AUTOHSCROLL | WS_DISABLED + EDITTEXT IDC_ROLE_ARN,209,59,106,14,ES_AUTOHSCROLL + EDITTEXT IDC_IDP_ARN,209,77,106,14,ES_AUTOHSCROLL + EDITTEXT IDC_IDP_PORT,209,94,106,14,ES_AUTOHSCROLL RTEXT "Authentication Mode:",IDC_STATIC,10,5,47,18,0,WS_EX_RIGHT - RTEXT "IAM Host:",IDC_STATIC,23,84,33,8,0,WS_EX_RIGHT - RTEXT "IAM Port:",IDC_STATIC,25,103,32,8,0,WS_EX_RIGHT - RTEXT "Expire Time (ms):",IDC_STATIC,17,42,40,16,0,WS_EX_RIGHT - LTEXT "Secret ID:",IDC_STATIC,23,122,34,8,0,WS_EX_RIGHT - RTEXT "Secret Endpoint:",IDC_STATIC,25,156,32,17,0,WS_EX_RIGHT - LTEXT "Secret Region:",IDC_STATIC,6,140,51,8,0,WS_EX_RIGHT + RTEXT "IAM Host:",IDC_STATIC,23,58,33,8,0,WS_EX_RIGHT + RTEXT "IAM Port:",IDC_STATIC,25,75,32,8,0,WS_EX_RIGHT + RTEXT "Expire Time (ms):",IDC_STATIC,17,37,40,16,0,WS_EX_RIGHT + LTEXT "Secret ID:",IDC_STATIC,23,92,34,8,0,WS_EX_RIGHT + RTEXT "Secret Endpoint:",IDC_STATIC,25,121,32,17,0,WS_EX_RIGHT + LTEXT "Secret Region:",IDC_STATIC,6,108,51,8,0,WS_EX_RIGHT RTEXT "IDP Username:",IDC_STATIC,166,5,39,17,0,WS_EX_RIGHT - RTEXT "IDP Password:",IDC_STATIC,168,23,37,18,0,WS_EX_RIGHT - RTEXT "IDP Endpoint:",IDC_STATIC,169,44,36,16,0,WS_EX_RIGHT - LTEXT "App ID:",IDC_STATIC,179,67,26,8,0,WS_EX_RIGHT - RTEXT "IAM Role ARN:",IDC_STATIC,171,83,34,16,0,WS_EX_RIGHT - RTEXT "IDP SAML ARN:",IDC_STATIC,173,103,32,15,0,WS_EX_RIGHT - RTEXT "IDP Port:",IDC_STATIC,175,126,30,8,0,WS_EX_RIGHT - CONTROL "",IDC_URL_ENCODE,"Button",BS_AUTOCHECKBOX | BS_LEFTTEXT | BS_MULTILINE | WS_TABSTOP,62,62,10,17 - LTEXT "Extra URL Encoding:",IDC_STATIC,21,61,36,18,0,WS_EX_RIGHT - LTEXT "AWS Region:",IDC_STATIC,13,27,44,8,0,WS_EX_RIGHT - EDITTEXT IDC_RELAY_PARTY_ID,209,144,106,14,ES_AUTOHSCROLL - RTEXT "Relaying Party ID:",IDC_STATIC,172,142,33,16 - EDITTEXT IDC_CONNECT_TIMEOUT,209,164,106,14,ES_AUTOHSCROLL - EDITTEXT IDC_SOCKET_TIMEOUT,62,178,96,14,ES_AUTOHSCROLL - RTEXT "Connect Timeout:",IDC_STATIC,163,163,42,17 - RTEXT "Socket Timeout:",IDC_STATIC,18,177,39,17 + RTEXT "IDP Password:",IDC_STATIC,168,22,37,18,0,WS_EX_RIGHT + RTEXT "IDP Endpoint:",IDC_STATIC,169,41,36,16,0,WS_EX_RIGHT + LTEXT "App ID:",IDC_STATIC,179,148,26,8,0,WS_EX_RIGHT + RTEXT "IAM Role ARN:",IDC_STATIC,171,57,34,16,0,WS_EX_RIGHT + RTEXT "IDP SAML ARN:",IDC_STATIC,173,76,32,15,0,WS_EX_RIGHT + RTEXT "IDP Port:",IDC_STATIC,175,96,30,8,0,WS_EX_RIGHT + CONTROL "",IDC_URL_ENCODE,"Button",BS_AUTOCHECKBOX | BS_LEFTTEXT | BS_MULTILINE | WS_TABSTOP,61,178,10,17 + LTEXT "Extra URL Encoding:",IDC_STATIC,21,177,36,18,0,WS_EX_RIGHT + LTEXT "AWS Region:",IDC_STATIC,13,25,44,8,0,WS_EX_RIGHT + EDITTEXT IDC_RELAY_PARTY_ID,209,112,106,14,ES_AUTOHSCROLL + RTEXT "Relaying Party ID:",IDC_STATIC,172,110,33,16 + EDITTEXT IDC_CONNECT_TIMEOUT,209,129,106,14,ES_AUTOHSCROLL + EDITTEXT IDC_SOCKET_TIMEOUT,62,141,96,14,ES_AUTOHSCROLL + RTEXT "Connect Timeout:",IDC_STATIC,163,128,42,17 + RTEXT "Socket Timeout:",IDC_STATIC,18,140,39,17 + COMBOBOX IDC_MFA_TYPE,62,161,96,30,CBS_DROPDOWN | CBS_SORT | WS_VSCROLL | WS_TABSTOP + LTEXT "MFA Type:",IDC_STATIC,32,159,25,19,0,WS_EX_RIGHT + EDITTEXT IDC_MFA_PORT,209,163,106,14,ES_AUTOHSCROLL + LTEXT "MFA Server Port:",IDC_STATIC,164,161,41,17,0,WS_EX_RIGHT + LTEXT "MFA Timeout (s):",IDC_STATIC,158,179,47,16,0,WS_EX_RIGHT + EDITTEXT IDC_MFA_TIMEOUT,209,181,106,14,ES_AUTOHSCROLL END IDC_TAB_FAILOVER DIALOGEX 5, 15, 322, 201 @@ -260,7 +266,7 @@ BEGIN VERTGUIDE, 205 VERTGUIDE, 209 TOPMARGIN, 7 - BOTTOMMARGIN, 194 + BOTTOMMARGIN, 195 END IDC_TAB_FAILOVER, DIALOG diff --git a/driver/gui/resource.h b/driver/gui/resource.h index 99aa7db..14ee904 100644 --- a/driver/gui/resource.h +++ b/driver/gui/resource.h @@ -75,6 +75,10 @@ #define IDC_CONNECT_TIMEOUT 1059 #define IDC_SOCKET_TIMEOUT 1061 #define IDC_DB_DIALECT 1062 +#define IDC_MFA_PORT 1063 +#define IDC_MFA_TIMEOUT 1064 +#define IDC_MFA_TYPE 1065 + // Next default values for new objects // @@ -82,7 +86,7 @@ #ifndef APSTUDIO_READONLY_SYMBOLS #define _APS_NEXT_RESOURCE_VALUE 126 #define _APS_NEXT_COMMAND_VALUE 40001 -#define _APS_NEXT_CONTROL_VALUE 1063 +#define _APS_NEXT_CONTROL_VALUE 1066 #define _APS_NEXT_SYMED_VALUE 101 #endif #endif diff --git a/driver/gui/setup.cpp b/driver/gui/setup.cpp index e29b561..efd9aa3 100644 --- a/driver/gui/setup.cpp +++ b/driver/gui/setup.cpp @@ -107,7 +107,10 @@ const std::map> FED_AUTH_KEYS = { {KEY_IDP_PORT, {IDC_IDP_PORT, EDIT_TEXT}}, {KEY_RELAY_PARTY_ID, {IDC_RELAY_PARTY_ID, EDIT_TEXT}}, {KEY_HTTP_CONNECT_TIMEOUT, {IDC_CONNECT_TIMEOUT, EDIT_TEXT}}, - {KEY_HTTP_SOCKET_TIMEOUT, {IDC_SOCKET_TIMEOUT, EDIT_TEXT}} + {KEY_HTTP_SOCKET_TIMEOUT, {IDC_SOCKET_TIMEOUT, EDIT_TEXT}}, + {KEY_MFA_TYPE, {IDC_MFA_TYPE, COMBO}}, + {KEY_MFA_PORT, {IDC_MFA_PORT, EDIT_TEXT}}, + {KEY_MFA_TIMEOUT, {IDC_MFA_TIMEOUT, EDIT_TEXT}} }; const std::map> FAILOVER_KEYS = { @@ -164,6 +167,12 @@ const std::vector> DB_DIALECTS = { {"Aurora PostgreSQL Limitless", VALUE_DB_DIALECT_AURORA_POSTGRESQL_LIMITLESS} }; +const std::vector> MFA_TYPES = { + {"None", ""}, + {"TOTP", VALUE_MFA_TOTP}, + {"Push", VALUE_MFA_PUSH} +}; + HINSTANCE ghInstance; HWND tab_control; HWND aws_auth_tab; @@ -252,6 +261,8 @@ std::string GetControlValue(HWND hwnd, std::pair pair) return LIMITLESS_MODES[selection].second; case IDC_DB_DIALECT: return DB_DIALECTS[selection].second; + case IDC_MFA_TYPE: + return MFA_TYPES[selection].second; default: break; } @@ -441,7 +452,6 @@ void TestConnection(HWND hwnd) RDS_AllocDbc(henv, &hdbc); std::string test_conn_str = GetDsn(true); - SQLRETURN ret = RDS_SQLDriverConnect( hdbc, nullptr, @@ -712,7 +722,10 @@ void HandleAuthModeSelection(HWND hwnd) { if (!IAM_KEYS.contains(keys.first) && !FED_AUTH_KEYS.contains(keys.first) && !AUTH_KEYS.contains(keys.first) || - id == IDC_APP_ID ) { + id == IDC_APP_ID || + id == IDC_MFA_TYPE || + id == IDC_MFA_PORT || + id == IDC_MFA_TIMEOUT) { show_ctrl = false; } break; @@ -761,7 +774,13 @@ BOOL AuthTabInit(HWND hwnd, HWND hwndFocus, LPARAM lParam) ComboBox_InsertString(auth_mode_box, i, RDS_TSTR(AWS_AUTH_MODES[i].first).c_str()); } + HWND mfa_type_box = GetDlgItem(hwnd, IDC_MFA_TYPE); + for (int i = 0; i < MFA_TYPES.size(); i++) { + ComboBox_InsertString(mfa_type_box, i, RDS_TSTR(MFA_TYPES[i].first).c_str()); + } + SetInitialComboBoxValue(hwnd, IDC_AUTH_MODE, KEY_AUTH_TYPE, AWS_AUTH_MODES); + SetInitialComboBoxValue(hwnd, IDC_MFA_TYPE, KEY_MFA_TYPE, MFA_TYPES); HandleAuthModeSelection(hwnd); diff --git a/driver/plugin/federated/http/AddrInformation.cpp b/driver/plugin/federated/http/AddrInformation.cpp new file mode 100755 index 0000000..59bcf6b --- /dev/null +++ b/driver/plugin/federated/http/AddrInformation.cpp @@ -0,0 +1,120 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "AddrInformation.h" + +#include +#include + +AddrInformation::AddrInformation( const std::string& port) +{ + addrinfo hints; + + memset(&hints, 0, sizeof(hints)); + + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + if (getaddrinfo("127.0.0.1", port.c_str(), &hints, &addrinfo_) != 0) { + throw std::runtime_error("Unable to get address information."); + } +} + +AddrInformation::~AddrInformation() +{ + freeaddrinfo(addrinfo_); +} + +AddrInformationIterator AddrInformation::begin() +{ + return AddrInformationIterator(addrinfo_); +} + +AddrInformationIterator AddrInformation::end() +{ + return AddrInformationIterator(nullptr); +} + +AddrInformationIterator AddrInformation::begin() const +{ + return AddrInformationIterator(addrinfo_); +} + +AddrInformationIterator AddrInformation::end() const +{ + return AddrInformationIterator(nullptr); +} + +AddrInformationIterator::AddrInformationIterator(addrinfo* addr) : addr_(addr) +{ + ; // Do nothing. +} + +AddrInformationIterator::AddrInformationIterator(const AddrInformationIterator& itr) + : addr_(itr.addr_) +{ + ; // Do nothing. +} + +AddrInformationIterator::~AddrInformationIterator() +{ + ; // Do nothing. +} + +AddrInformationIterator AddrInformationIterator::operator++(int) +{ + AddrInformationIterator tmp(*this); + + if (addr_) + { + addr_ = addr_->ai_next; + } + + return tmp; +} + +AddrInformationIterator& AddrInformationIterator::operator++() +{ + if (addr_) { + addr_ = addr_->ai_next; + } + + return *this; +} + +bool AddrInformationIterator::operator!=(const AddrInformationIterator& itr) const +{ + return addr_ != itr.addr_; +} + +bool AddrInformationIterator::operator==(const AddrInformationIterator& itr) const +{ + return addr_ == itr.addr_; +} + +addrinfo* AddrInformationIterator::operator->() +{ + return addr_; +} + +const addrinfo* AddrInformationIterator::operator*() const +{ + return addr_; +} + +addrinfo* AddrInformationIterator::operator*() +{ + return addr_; +} diff --git a/driver/plugin/federated/http/AddrInformation.h b/driver/plugin/federated/http/AddrInformation.h new file mode 100755 index 0000000..f20de13 --- /dev/null +++ b/driver/plugin/federated/http/AddrInformation.h @@ -0,0 +1,73 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "SocketSupport.h" +#include + +/* +* This class is used to support range-based for loop and iterators +* in AddrInformation class. +*/ +class AddrInformationIterator { +public: + AddrInformationIterator(addrinfo *addr); + + AddrInformationIterator(const AddrInformationIterator& itr); + + ~AddrInformationIterator(); + + AddrInformationIterator& operator++(); + + AddrInformationIterator operator++(int); + + bool operator!=(const AddrInformationIterator& itr) const; + + bool operator==(const AddrInformationIterator& itr) const; + + addrinfo* operator->(); + + const addrinfo* operator*() const; + + addrinfo* operator*(); + +private: + addrinfo *addr_; + +}; + +/* +* This class is used to perform network address and service +* translation via getaddrinfo call. Returns one or more addrinfo structures, each +* of which contains an Internet address. +*/ +class AddrInformation +{ +public: + AddrInformation( const std::string& port); + + ~AddrInformation(); + + AddrInformationIterator begin(); + + AddrInformationIterator end(); + + AddrInformationIterator begin() const; + + AddrInformationIterator end() const; + +private: + addrinfo *addrinfo_; +}; diff --git a/driver/plugin/federated/http/HtmlResponse.h b/driver/plugin/federated/http/HtmlResponse.h new file mode 100755 index 0000000..bcfd15c --- /dev/null +++ b/driver/plugin/federated/http/HtmlResponse.h @@ -0,0 +1,44 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +inline std::string GetResponse = + "HTTP/1.1 200 OK\r\n" + "Content-Length: 300\r\n" + "Connection: close\r\n" + "Content-Type: text/html; charset=utf-8\r\n\r\n" + "

AWS Advanced ODBC Wrapper Okta Plugin MFA Authentication Code

" + "
" + "
" + "
" + "" + "
"; + +const std::string ValidResponse = + "HTTP/1.1 200 OK\r\n" + "Content-Length: 300\r\n" + "Connection: close\r\n" + "Content-Type: text/html; charset=utf-8\r\n\r\n" + "

" + "Thank you for using the AWS Advanced ODBC Wrapper! You can now close this window.

"; + +const std::string InvalidResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Content-Length: 95\r\n" + "Connection: close\r\n" + "Content-Type: text/html; charset=utf-8\r\n\r\n" + "The request could not be understood by the server!

"; diff --git a/driver/plugin/federated/http/Parser.cpp b/driver/plugin/federated/http/Parser.cpp new file mode 100755 index 0000000..63195b4 --- /dev/null +++ b/driver/plugin/federated/http/Parser.cpp @@ -0,0 +1,174 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "../../../util/logger_wrapper.h" +#include "Parser.h" + +#include +#include +#include +#include +#include +#include + +void Parser::ParseRequestLine(std::string& str) +{ + std::istringstream istr(str); + std::vector command( + (std::istream_iterator(istr)), + std::istream_iterator() + ); + + if ((command.size() == 3) && (command[0].find(METHOD) != std::string::npos) && + (command[2].find(HTTP_VERSION) != std::string::npos)) { + parser_state_ = STATE::PARSE_HEADER; + } else if ((command.size() == 3) && + (command[0].find(GET_METHOD) != std::string::npos) && + command[2].find(HTTP_VERSION) != std::string::npos) { + const size_t question_mark_pos = command[1].find('?'); + if (question_mark_pos != std::string::npos) { + std::string parsed_body = command[1].substr(question_mark_pos + 1); + ParseBodyLine(parsed_body); + } else { + parser_state_ = STATE::PARSE_GET_REQUEST; + } + } else { + throw std::runtime_error("Request line contains wrong information."); + } +} + +void Parser::ParseHeaderLine(std::string& str) +{ + if (str == "\r") { + parser_state_ = STATE::PARSE_BODY; + + return; + } + + const size_t ind = str.find(':'); + header_size_ += str.size(); + + if ((ind == std::string::npos) || (header_size_ > MAX_HEADER_SIZE)) { + throw std::runtime_error("Received invalid header line."); + } + + str.erase(std::remove_if(str.begin() + static_cast(ind), str.end(), ::isspace), str.end()); + + header_.insert({ + std::string(str.begin(), str.begin() + static_cast(ind)), + std::string(str.begin() + static_cast(ind) + 1, str.end()) + }); +} + +void Parser::ParseBodyLine(std::string& str) +{ + auto str_begin = str.begin(); + auto str_end = str.end(); + + while (true) { + auto equal_it = find(str_begin, str_end, '='); + auto ampersand_it = find(str_begin, str_end, '&'); + + if (equal_it == str_end) { + throw std::runtime_error("Received invalid body line."); + } + + body_.insert({ + std::string(str_begin, equal_it), + std::string(equal_it + 1, ampersand_it) + }); + + if ((equal_it == str_end) || (ampersand_it == str_end)) { + break; + } + + str_begin = ampersand_it + 1; + } + + parser_state_ = STATE::PARSE_FINISHED; +} + +void Parser::ParsePostRequest(std::string& str) +{ + switch (parser_state_) { + case STATE::PARSE_REQUEST: + ParseRequestLine(str); + break; + case STATE::PARSE_HEADER: + ParseHeaderLine(str); + break; + case STATE::PARSE_BODY: + if (!header_.contains("Content-Type") || !header_.contains("Content-Length") || + (header_["Content-Type"] != "application/x-www-form-urlencoded") || + (header_["Content-Length"] != std::to_string(str.size()))) + { + throw std::runtime_error("Can't start parsing body as header contains invalid information."); + } + + ParseBodyLine(str); + break; + default: + break; + } +} + +STATUS Parser::Parse(std::istream &in) +{ + std::string str; + bool is_line_received = false; + + while (getline(in, str) && parser_state_ != STATE::PARSE_GET_REQUEST) { + is_line_received = true; + + try { + ParsePostRequest(str); + } catch (std::exception& e) { + DLOG(INFO) << e.what(); + break; + } + } + + if (parser_state_ == STATE::PARSE_GET_REQUEST) { + parser_state_ = STATE::PARSE_REQUEST; + return STATUS::GET_SUCCESS; + } + + if (parser_state_ != STATE::PARSE_FINISHED) { + return is_line_received ? STATUS::FAILED : STATUS::EMPTY_REQUEST; + } + + return STATUS::SUCCEED; +} + +Parser::Parser() + : parser_state_(STATE::PARSE_REQUEST) + , header_size_(0) +{ + ; // Do nothing. +} + +bool Parser::IsFinished() const +{ + return parser_state_ == STATE::PARSE_FINISHED; +} + +std::string Parser::RetrieveAuthCode(std::string& state) +{ + if (body_.contains("code")) + { + return body_["code"]; + } + + return ""; +} diff --git a/driver/plugin/federated/http/Parser.h b/driver/plugin/federated/http/Parser.h new file mode 100755 index 0000000..6b13934 --- /dev/null +++ b/driver/plugin/federated/http/Parser.h @@ -0,0 +1,114 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +enum class STATE +{ + PARSE_REQUEST, + PARSE_HEADER, + PARSE_BODY, + PARSE_FINISHED, + PARSE_GET_REQUEST +}; + +enum class STATUS +{ + SUCCEED, + FAILED, + EMPTY_REQUEST, + GET_SUCCESS +}; + +/* +* This class is used to parse the HTTP POST request +* and retrieve authorization code. +*/ +class Parser { + public: + + Parser(); + + ~Parser() = default; + + /* + * Initiate request parsing. + * + * Return true if parse was successful, otherwise false. + */ + STATUS Parse(std::istream &in); + + /* + * Check if parser is finished to parse the POST request. + * + * Return true if parsing was successfully finished, otherwise false. + */ + bool IsFinished() const; + + /* + * Retrieve authorization code. + * + * Return received authorization code or empty string. + */ + std::string RetrieveAuthCode(std::string& state); + + private: + /* + * Parse received POST request line by line. + * + * Return void or throw an exception if parse wasn't successful. + */ + void ParsePostRequest(std::string& str); + + /* + * Parse request-line and perform verification for: + * method, Request-URI and HTTP-Version. + * + * Return void or throw an exception if parse wasn't successful. + */ + void ParseRequestLine(std::string& str); + + /* + * Parse request header line. + * + * Return void or throw an exception if parse wasn't successful. + */ + void ParseHeaderLine(std::string& str); + + /* + * Parse request body line in application/x-www-form-urlencoded format. + * + * Return void or throw an exception if parse wasn't successful. + */ + void ParseBodyLine(std::string& str); + + /* + * Expected METHOD, URI and HTTP VERSION in request line. + */ + const std::string METHOD = "POST"; + // const std::string URI = "/redshift/"; + const std::string HTTP_VERSION = "HTTP/1.1"; + const int MAX_HEADER_SIZE = 8192; + + const std::string GET_METHOD = "GET"; + const std::string PKCE_URI = "/?code="; + + STATE parser_state_; + size_t header_size_; + std::unordered_map header_; + std::unordered_map body_; + +}; diff --git a/driver/plugin/federated/http/Selector.cpp b/driver/plugin/federated/http/Selector.cpp new file mode 100644 index 0000000..5dd7f79 --- /dev/null +++ b/driver/plugin/federated/http/Selector.cpp @@ -0,0 +1,56 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Selector.h" + +void Selector::Register(SOCKET sfd) +{ + if (sfd == -1) { + return; + } + + FD_SET(sfd, &master_fds_); + +#ifdef WIN32 + max_fd_ = max(max_fd_, sfd); +#else + max_fd_ = std::max(max_fd_, sfd); +#endif +} + +void Selector::Unregister(SOCKET sfd) +{ + FD_CLR(sfd, &master_fds_); +} + +bool Selector::Select(struct timeval *tv) +{ + // As select will modify the file descriptor set we should keep temporary set to reflect the ready fd. + fd_set read_fds = master_fds_; + + if (select(max_fd_ + 1, &read_fds, nullptr, nullptr, tv) > 0) { + for (SOCKET sfd = 0; sfd <= max_fd_; sfd++) { + if (FD_ISSET(sfd, &read_fds)) { + return true; + } + } + } + + return false; +} + +Selector::Selector() : max_fd_(0) +{ + FD_ZERO(&master_fds_); +} diff --git a/driver/plugin/federated/http/Selector.h b/driver/plugin/federated/http/Selector.h new file mode 100755 index 0000000..d48a7de --- /dev/null +++ b/driver/plugin/federated/http/Selector.h @@ -0,0 +1,52 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "SocketSupport.h" + +#include + +/* +* This class is used to support insertion/deletion operations on +* master descriptor set and selecting incoming connections. +*/ +class Selector +{ + public: + Selector(); + + ~Selector() = default; + + /* + * Update master descriptor set with new socket. + */ + void Register(SOCKET sfd); + + /* + * Remove socket from master descriptor set. + */ + void Unregister(SOCKET sfd); + + /* + * If any incoming connections is readable return true, or + * false otherwise. + */ + bool Select(struct timeval* tv); + + private: + + SOCKET max_fd_; + fd_set master_fds_; +}; diff --git a/driver/plugin/federated/http/Socket.cpp b/driver/plugin/federated/http/Socket.cpp new file mode 100755 index 0000000..328f7f9 --- /dev/null +++ b/driver/plugin/federated/http/Socket.cpp @@ -0,0 +1,214 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "AddrInformation.h" +#include "Socket.h" + +#include +#include +#include + +int Socket::Receive(char *buffer, int length, int flags) const +{ + int nbytes = 0; + int filled = 0; + auto start = std::chrono::system_clock + ::now(); + + const int receive_wait = 200; + // Give a chance to fully receive packet in case if there is no data in + // non-blocking socket. + while ((std::chrono::system_clock::now() - start < std::chrono::seconds(1))) { + nbytes = static_cast(recv(socket_fd_, buffer + filled, length - filled - 1, 0)); + + if (nbytes <= 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(receive_wait)); + } else { + filled += nbytes; + } + } + + return filled == 0 ? nbytes : filled; +} + +int Socket::Send(const char *buffer, int length, int flags) const + { + int nbytes = 0; + int sent = 0; + + while ((nbytes = static_cast(send(socket_fd_, buffer + sent, length - sent, flags))) > 0) { + sent += nbytes; + } + + return sent == 0 ? nbytes : sent; +} + +void Socket::PrepareListenSocket(const std::string& port) +{ + const AddrInformation addr_info(port); + + for (const auto& ptr : addr_info) { + socket_fd_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol); + + if (socket_fd_ == INVALID_SOCKET) { + continue; + } + + if (!SetReusable()) { + Close(); + + throw std::runtime_error("Unable to reuse the address."); + } + + if (Bind(ptr->ai_addr, ptr->ai_addrlen) == 0) { + break; + } + + Close(); + } + + if (!SetNonBlocking()) { + Close(); + + throw std::runtime_error("Unable to set socket to non-blocking mode."); + } + + if ((socket_fd_ == INVALID_SOCKET) || (Listen(CONNECTION_BACKLOG))) { + Close(); + + throw std::runtime_error("Can not start listening on port: " + port); + } +} + +bool Socket::Close() +{ + if (socket_fd_ == -1) { + return false; + } + +#if (defined(_WIN32) || defined(_WIN64)) + int res = closesocket(socket_fd_); +#else + const int res = close(socket_fd_); +#endif + + socket_fd_ = INVALID_SOCKET; + + return res == 0; +} + +Socket::Socket() : socket_fd_(INVALID_SOCKET) +{ + ; // Do nothing. +} + +Socket::Socket(SOCKET sfd) : socket_fd_(sfd) +{ + ; // Do nothing. +} + +Socket::Socket(Socket&& s) +{ + socket_fd_ = std::move(s.socket_fd_); + + s.socket_fd_ = INVALID_SOCKET; +} + +Socket& Socket::operator=(Socket&& s) +{ + socket_fd_ = std::move(s.socket_fd_); + + s.socket_fd_ = INVALID_SOCKET; + + return *this; +} + +Socket::~Socket() +{ + Close(); +} + +bool Socket::SetNonBlocking() const +{ +#if (defined(_WIN32) || defined(_WIN64)) + unsigned long mode = 1; + return ioctlsocket(socket_fd_, FIONBIO, &mode) == 0 ? true : false; +#else + int flags = fcntl(socket_fd_, F_GETFL, 0); + return fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK) == 0 ? true : false; +#endif +} + +bool Socket::IsNonBlockingError() const +{ +#if (defined(_WIN32) || defined(_WIN64)) + return WSAGetLastError() == WSAEWOULDBLOCK; +#else + return errno == EWOULDBLOCK; +#endif +} + +void Socket::Register(Selector& selector) const +{ + selector.Register(socket_fd_); +} + +void Socket::Unregister(Selector& selector) const +{ + selector.Unregister(socket_fd_); +} + +int Socket::Bind(const struct sockaddr *address, size_t address_len) const +{ + return bind(socket_fd_, address, (int)address_len); +} + +int Socket::GetListenPort() const +{ + sockaddr_storage addr; + socklen_t len = sizeof(addr); + int port = 0; + + getsockname(socket_fd_, (struct sockaddr*)&addr, &len); + port = htons(((sockaddr_in*)&addr)->sin_port); + + return port; +} + +int Socket::Listen(int backlog) const +{ + return listen(socket_fd_, backlog); +} + +Socket Socket::Accept() const +{ + sockaddr_storage remoteaddr; + socklen_t addrlen = sizeof(remoteaddr); + + return Socket(accept(socket_fd_, (struct sockaddr*)&remoteaddr, &addrlen)); +} + +bool Socket::SetReusable() const +{ + int yes = 1; + + // Windows: int setsockopt(SOCKET s, int level, int optname, const char *optval, int optlen); + // The optval parameter has ponter to const char, but according to the MSDN we should use int: + // To enable a Boolean option, the optval parameter points to a nonzero integer. + // To disable the option optval points to an integer equal to zero. + // The optlen parameter should be equal to sizeof(int) for Boolean options. + // On Linux the optval parameter is pointer to void. + return setsockopt(socket_fd_, SOL_SOCKET, SO_REUSEADDR, + (char *)&yes, sizeof(yes)) == 0; +} diff --git a/driver/plugin/federated/http/Socket.h b/driver/plugin/federated/http/Socket.h new file mode 100755 index 0000000..2d4e465 --- /dev/null +++ b/driver/plugin/federated/http/Socket.h @@ -0,0 +1,126 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOCKET_H_ +#define SOCKET_H_ + +#pragma once + +#include "SocketSupport.h" +#include "Selector.h" +#include + +/* +* This class is used to wrap the network socket +* in order to have cross-platfrom code to send and receive +* data from incoming connections. +*/ +class Socket +{ + public: + Socket(); + + Socket( SOCKET sfd); + + Socket(const Socket& s) = delete; + + Socket& operator=(const Socket& s) = delete; + + Socket(Socket&& s); + + Socket& operator=(Socket&& s); + + ~Socket(); + + /* + * Close a socket file descriptor, return true on success + * or false in case of error. + */ + bool Close(); + + /* + * Get port where socket is listening. + */ + int GetListenPort() const; + + /* + * Listen for connections on a socket and return zero on success, + * or -1 in case of error. + */ + int Listen(int backlog) const; + + /* + * Accept a connection on a socket and return StreamSocket object. + */ + Socket Accept() const; + + /* + * Forcibly bind to a port in use by another socket and return true on success, + * or false in case of error. + */ + bool SetReusable() const; + + /* + * Set socket to non-blocking mode and return true on success + * or false in case of error. + */ + bool SetNonBlocking() const; + + /* + * Return true if error is caused by non-blocking mode of socket + * otherwise return false. + */ + bool IsNonBlockingError() const; + + /* + * Prepare socket to handle incoming connections. + */ + void PrepareListenSocket(const std::string& port); + + /* + * Register socket in master file descriptor set using Selector class. + */ + void Register(Selector& selector) const; + + /* + * Clear socket in master file descriptor set using Selector class. + */ + void Unregister(Selector& selector) const; + + /* + * Receive a message from a socket and return the number of bytes received, + * or -1 if an error occurred. + */ + int Receive(char *buffer, int length, int flags) const; + + /* + * Send a message on a socket and return the number of bytes sent, + * or -1 if an error occured. + */ + int Send(const char *buffer, int length, int flags) const; + + /* + * Bind a name to a socket and return zero on success, + * or -1 in case of error. + */ + int Bind(const struct sockaddr *address, size_t address_len) const; + + private: + + const int CONNECTION_BACKLOG = 10; + + SOCKET socket_fd_; +}; + +#endif // SOCKET_H_ diff --git a/driver/plugin/federated/http/SocketStream.cpp b/driver/plugin/federated/http/SocketStream.cpp new file mode 100644 index 0000000..1d41a3c --- /dev/null +++ b/driver/plugin/federated/http/SocketStream.cpp @@ -0,0 +1,50 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "SocketStream.h" + +SocketStream::SocketStream(Socket& s) : received_size_(0), socket_(s) +{ + setg(input_buffer_, input_buffer_, input_buffer_); +} + +SocketStream::~SocketStream() +{ + sync(); +} + +int SocketStream::underflow() +{ + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + int received_bytes = socket_.Receive(input_buffer_, SIZE - 1, 0); + + /* + * Return EOF in the following cases: + * If the received bytes less than zero (error situation); + * If the received bytes equal to zero (socket peer has performed an orderly shutdown); + * If the length of received packets more than MAX_SIZE. + */ + if ((received_bytes <= 0) || (received_size_ + received_bytes > MAX_SIZE)) { + return traits_type::eof(); + } + + received_size_ += received_bytes; + + setg(input_buffer_, input_buffer_, input_buffer_ + received_bytes); + + return traits_type::to_int_type(*gptr()); +} \ No newline at end of file diff --git a/driver/plugin/federated/http/SocketStream.h b/driver/plugin/federated/http/SocketStream.h new file mode 100755 index 0000000..5d1eb7b --- /dev/null +++ b/driver/plugin/federated/http/SocketStream.h @@ -0,0 +1,59 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOCKET_STREAM_H_ +#define SOCKET_STREAM_H_ + +#pragma once + +#include "Socket.h" + +#include + +/* +* This class is used to wrap socket by std::stream. +*/ +class SocketStream : public std::streambuf +{ + public: + /* + * Construct object and set the value for the pointers that define + * the boundaries of the buffered portion of the controlled INPUT sequence. + */ + SocketStream(Socket& s); + + /* + * Call sync inside the destructor to synchronize the contents in the stream buffer + * with those of the associated character sequence. + */ + virtual ~SocketStream(); + + protected: + /* + * Virtual function called by other member functions to get the current character + * in the controlled input sequence without changing the current position. + * It is called by public member functions such as sgetc to request + * a new character when there are no read positions available at the get pointer (gptr). + */ + virtual int underflow(); + + static const int SIZE = 2048; + static const int MAX_SIZE = 16384; + + int received_size_; + char input_buffer_[SIZE]; + Socket& socket_; +}; + +#endif // SOCKET_STREAM_H_ diff --git a/driver/plugin/federated/http/SocketSupport.h b/driver/plugin/federated/http/SocketSupport.h new file mode 100644 index 0000000..074c0d2 --- /dev/null +++ b/driver/plugin/federated/http/SocketSupport.h @@ -0,0 +1,35 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#if (defined(_WIN32) || defined(_WIN64)) + +#include +#include +#include "Ws2tcpip.h" +#pragma comment(lib, "Ws2_32.lib") + +#else /* Linux or MAC */ + +#include +#include +#include +#include +#include + +typedef int SOCKET; +#define INVALID_SOCKET -1 + +#endif diff --git a/driver/plugin/federated/http/WEBServer.cpp b/driver/plugin/federated/http/WEBServer.cpp new file mode 100755 index 0000000..1b84cf5 --- /dev/null +++ b/driver/plugin/federated/http/WEBServer.cpp @@ -0,0 +1,153 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "../../../util/logger_wrapper.h" +#include "HtmlResponse.h" +#include "SocketStream.h" +#include "WEBServer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +WEBServer::WEBServer( std::string& state, + std::string& port, std::string& timeout) : + state_(state), + port_(port), + timeout_(std::stoi(timeout)), + selector_(), + parser_(), + listen_socket_(), + listen_port_(0), + connections_counter_(0), + listening_(false) +{ + ; // Do nothing. +} + +void WEBServer::LaunchServer() +{ + // Create thread to launch the server to listen the incoming connection. + // Waiting for the redirect response from the /oauth2/authorize. + thread_ = std::thread(&WEBServer::ListenerThread, this); +} + +void WEBServer::Join() +{ + if (thread_.joinable()) { + thread_.join(); + } +} + +std::string WEBServer::GetCode() const +{ + return code_; +} + +int WEBServer::GetListenPort() const +{ + return listen_port_; +} + +bool WEBServer::IsListening() const +{ + return listening_.load(); +} + +bool WEBServer::IsTimeout() const +{ + return connections_counter_ > 0 ? false : true; +} + +void WEBServer::Cancel() { + cancel_ = true; +} + +bool WEBServer::WEBServerInit() +{ + // Prepare the environment for get the socket description. + try { + listen_socket_.PrepareListenSocket(port_); + listen_socket_.Register(selector_); + listen_port_ = listen_socket_.GetListenPort(); + } catch (std::exception& e) { + DLOG(INFO) << "Exception: " << e.what(); + + return false; + } + + return true; +} + +void WEBServer::HandleConnection() +{ + /* Trying to accept the pending incoming connection. */ + Socket ssck(listen_socket_.Accept()); + + ++connections_counter_; + + if (ssck.SetNonBlocking()) { + SocketStream socket_buffer(ssck); + std::istream socket_stream(&socket_buffer); + + const STATUS status = parser_.Parse(socket_stream); + + if (status == STATUS::SUCCEED) { + ssck.Send(ValidResponse.c_str(), static_cast(ValidResponse.size()), 0); + } else if (status == STATUS::FAILED) { + ssck.Send(InvalidResponse.c_str(), static_cast(InvalidResponse.size()), 0); + } else if (status == STATUS::GET_SUCCESS) { + const std::string response = std::vformat(GetResponse, std::make_format_args(listen_port_)); + ssck.Send(response.c_str(), static_cast(response.size()), 0); + } else { + /* Nothing is received from socket. Continue to listen. */ + return; + } + } +} + +void WEBServer::Listen() +{ + // Set timeout for non-blocking socket to 1 sec to pass it to Select. + struct timeval tv = { .tv_sec=1, .tv_usec=0 }; + + if (selector_.Select(&tv)) { + HandleConnection(); + } +} + +void WEBServer::ListenerThread() +{ + if (!WEBServerInit()) { + DLOG(INFO) << "WEBServerInit Failed"; + return; + } + + auto start = std::chrono::system_clock::now(); + + listening_.store(true); + + while ((std::chrono::system_clock::now() - start < std::chrono::seconds(timeout_)) && !parser_.IsFinished()) { + Listen(); + } + + code_ = parser_.RetrieveAuthCode(state_); + + listen_socket_.Close(); +} diff --git a/driver/plugin/federated/http/WEBServer.h b/driver/plugin/federated/http/WEBServer.h new file mode 100755 index 0000000..4be1205 --- /dev/null +++ b/driver/plugin/federated/http/WEBServer.h @@ -0,0 +1,118 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBSERVER_H_ +#define WEBSERVER_H_ + +#pragma once + +#include "Parser.h" +#include "Selector.h" +#include "Socket.h" + +#include +#include + +/* +* This class is used to launch the HTTP WEB server in separate thread +* to wait for the redirect response from the /oauth2/authorize and +* extract the authorization code from it. +*/ +class WEBServer +{ + public: + WEBServer( + std::string& state, + std::string& port, + std::string& timeout); + + ~WEBServer() = default; + + /* + * Launch the HTTP WEB server in separate thread. + */ + void LaunchServer(); + + /* + * Wait until HTTP WEB server is finished. + */ + void Join(); + + /* + * Extract the authorization code from response. + */ + std::string GetCode() const; + + /* + * Get port where server is listening. + */ + int GetListenPort() const; + + /* + * Extract the SAML Assertion from response. + */ + std::string GetSamlResponse() const; + + /* + * If server is listening for connections return true, otherwise return false. + */ + bool IsListening() const; + + /* + * If timeout happened return true, otherwise return false. + */ + bool IsTimeout() const; + + /* + * Cancel listen loop prematurely without waiting for the timeout. + */ + void Cancel(); + + private: + /* + * Main HTTP WEB server function that perform initialization and + * listen for incoming connections for specified time by user. + */ + void ListenerThread(); + + /* + * If incoming connection is available call HandleConnection. + */ + void Listen(); + + /* + * Launch parser if incoming connection is acceptable. + */ + void HandleConnection(); + + /* + * Perform socket preparation to launch the HTTP WEB server. + */ + bool WEBServerInit(); + + std::string state_; + std::string port_; + int timeout_; + std::string code_; + std::thread thread_; + Selector selector_; + Parser parser_; + Socket listen_socket_; + int listen_port_; + int connections_counter_; + std::atomic listening_; + bool cancel_ = false; +}; + +#endif // WEBSERVER_H_ diff --git a/driver/plugin/federated/http/WEBServer_utils.cpp b/driver/plugin/federated/http/WEBServer_utils.cpp new file mode 100644 index 0000000..49c4f2a --- /dev/null +++ b/driver/plugin/federated/http/WEBServer_utils.cpp @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "WEBServer_utils.h" + +int WebServerUtils::GenerateRandomInteger(int low, int high) +{ + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution<> dist(low, high); + + return dist(generator); +} + +std::string WebServerUtils::GenerateState() +{ + const char STATE_CHAR_LIST[37] = "0123456789abcdefghijklmnopqrstuvwxyz"; + const int chars_size = (sizeof(STATE_CHAR_LIST) / sizeof(*STATE_CHAR_LIST)) - 1; + const int rand_size = GenerateRandomInteger(9, chars_size - 1); + std::string state; + + state.reserve(rand_size); + + for (int i = 0; i < rand_size; ++i) { + state.push_back(STATE_CHAR_LIST[GenerateRandomInteger(0, rand_size)]); + } + + return state; +} diff --git a/driver/plugin/federated/http/WEBServer_utils.h b/driver/plugin/federated/http/WEBServer_utils.h new file mode 100644 index 0000000..d159738 --- /dev/null +++ b/driver/plugin/federated/http/WEBServer_utils.h @@ -0,0 +1,25 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBSERVER_UTILS_H_ +#define WEBSERVER_UTILS_H_ + +#include + +namespace WebServerUtils { + int GenerateRandomInteger(int low, int high); + std::string GenerateState(); +} + +#endif // WEBSERVER_UTILS_H_ diff --git a/driver/plugin/federated/okta_auth_plugin.cpp b/driver/plugin/federated/okta_auth_plugin.cpp index 95c66ce..23028ee 100644 --- a/driver/plugin/federated/okta_auth_plugin.cpp +++ b/driver/plugin/federated/okta_auth_plugin.cpp @@ -12,17 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "okta_auth_plugin.h" +#if (defined(_WIN32) || defined(_WIN64)) + #include + #include + #undef GetObject +#endif -#include +#include "okta_auth_plugin.h" #include "html_util.h" +#include "http/WEBServer.h" +#include "http/WEBServer_utils.h" +#include "saml_util.h" -#include "../../util/aws_sdk_helper.h" #include "../../util/connection_string_keys.h" #include "../../util/logger_wrapper.h" +#include "../../util/rds_strings.h" #include "../../util/rds_utils.h" -#include "saml_util.h" OktaAuthPlugin::OktaAuthPlugin(DBC *dbc) : OktaAuthPlugin(dbc, nullptr) {} @@ -113,7 +119,7 @@ SQLRETURN OktaAuthPlugin::Connect( ret = next_plugin->Connect(ConnectionHandle, WindowHandle, OutConnectionString, BufferLength, StringLengthPtr, DriverCompletion); // Unsuccessful connection using cached token - // Skip cache and generate a new token to retry + // Skip cache and generate a new token to retry if (!SQL_SUCCEEDED(ret) && token.second) { LOG(WARNING) << "Cached token failed to connect. Retrying with fresh token"; // Update AWS Credentials @@ -145,6 +151,16 @@ OktaSamlUtil::OktaSamlUtil( } sign_in_url = "https://" + idp_endpoint + ":" + idp_port + "/app/amazon_aws/" + app_id + "/sso/saml" + "?onetimetoken="; session_token_url = "https://" + idp_endpoint + ":" + idp_port + "/api/v1/authn"; + + const std::string mfa_type_str = connection_attributes.contains(KEY_MFA_TYPE) ? + connection_attributes.at(KEY_MFA_TYPE) : ""; + if (mfa_type_table.contains(mfa_type_str)) { + mfa_type = mfa_type_table.at(mfa_type_str); + } + mfa_port = connection_attributes.contains(KEY_MFA_PORT) ? + connection_attributes.at(KEY_MFA_PORT) : DEFAULT_PORT; + mfa_timeout = connection_attributes.contains(KEY_MFA_TIMEOUT) ? + connection_attributes.at(KEY_MFA_TIMEOUT) : DEFAULT_MFA_TIMEOUT; } std::string OktaSamlUtil::GetSamlAssertion() @@ -214,10 +230,169 @@ std::string OktaSamlUtil::GetSessionToken() LOG(ERROR) << "Unable to parse JSON from response"; return ""; } + + if (mfa_type != NONE) { + const Aws::Utils::Json::JsonView json_view = json_val.View(); + + if (!json_view.KeyExists("stateToken")) { + LOG(ERROR) << "Could not find state token in JSON"; + return ""; + } + + const std::string state_token = json_view.GetString("stateToken"); + const Aws::Utils::Json::JsonView embedded_view = json_view.GetObject("_embedded"); + Aws::Utils::Array factor_views = embedded_view.GetArray("factors"); + + std::string factor_id; + for (int i = 0; i < factor_views.GetLength(); i++) { + const std::string type = factor_views[i].GetString("factorType"); + if (mfa_type == TOTP && type == "token:software:totp" || mfa_type == PUSH && type == "push") { + factor_id = factor_views[i].GetString("id"); + } + } + + if (factor_id.empty()) { + LOG(ERROR) << "Could not find factor in JSON"; + return ""; + } + + const std::string verify_url = session_token_url + "/factors/" + factor_id + "/verify"; + if (mfa_type == TOTP) { + return VerifyTOTPChallenge(verify_url, state_token); + } + if (mfa_type == PUSH) { + return VerifyPushChallenge(verify_url, state_token); + } + } + const Aws::Utils::Json::JsonView json_view = json_val.View(); if (!json_view.KeyExists("sessionToken")) { LOG(ERROR) << "Could not find session token in JSON"; return ""; } + return json_view.GetString("sessionToken"); } + +std::string OktaSamlUtil::VerifyTOTPChallenge( + const std::string &verify_url, + const std::string &state_token) +{ + std::string state = WebServerUtils::GenerateState(); + WEBServer srv(state, mfa_port, mfa_timeout); + + srv.LaunchServer(); + + const std::string mfa_form_url = WEBSERVER_HOST + ":" + mfa_port; + try { +#if (defined(_WIN32) || defined(_WIN64)) + const HINSTANCE result = ShellExecute(NULL, RDS_TSTR(std::string("open")).c_str(), RDS_TSTR(mfa_form_url).c_str(), NULL, NULL, SW_SHOWNORMAL); + if (reinterpret_cast(result) <= 32) { + srv.Cancel(); + } +#else +#if (defined(LINUX) || defined(__linux__)) + const int result = system(("xdg-open " + mfa_form_url).c_str()); +#else + const int result = system(("open " + mfa_form_url).c_str()); +#endif + if (result != 0) { + srv.Cancel(); + } +#endif + } catch (const std::exception & e) { + srv.Cancel(); + srv.Join(); + LOG(ERROR) << "Could not open browser to obtain MFA token: " << e.what(); + return ""; + } + + srv.Join(); + + const std::string pass_code = srv.GetCode(); + if (pass_code.empty()) { + LOG(ERROR) << "MFA Authorization code was not obtained"; + return ""; + } + + const std::shared_ptr req = Aws::Http::CreateHttpRequest( + verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + Aws::Utils::Json::JsonValue json_body; + json_body + .WithString("stateToken", state_token) + .WithString("passCode", pass_code); + const Aws::String json_str = json_body.View().WriteReadable(); + const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size()); + req->SetContentType("application/json"); + req->AddContentBody(Aws::MakeShared("", json_str)); + req->SetContentLength(json_len); + const std::shared_ptr response = http_client->MakeRequest(req); + + // Check resp status + if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) { + LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode(); + if (response->HasClientError()) { + LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage(); + } + return ""; + } + + const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody()); + if (!json_val.WasParseSuccessful()) { + LOG(ERROR) << "Unable to parse JSON from response"; + return ""; + } + + const Aws::Utils::Json::JsonView json_view = json_val.View(); + if (!json_view.KeyExists("sessionToken")) { + LOG(ERROR) << "Could not find session token in JSON"; + return ""; + } + + return json_view.GetString("sessionToken"); +} + +std::string OktaSamlUtil::VerifyPushChallenge( + const std::string &verify_url, + const std::string &state_token) +{ + const std::chrono::time_point end_time = std::chrono::system_clock::now() + std::chrono::seconds(std::strtol(mfa_timeout.c_str(), nullptr, 0)); + while (std::chrono::system_clock::now() < end_time) { + const std::shared_ptr req = Aws::Http::CreateHttpRequest( + verify_url, Aws::Http::HttpMethod::HTTP_POST, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + Aws::Utils::Json::JsonValue json_body; + json_body.WithString("stateToken", state_token); + const Aws::String json_str = json_body.View().WriteReadable(); + const Aws::String json_len = Aws::Utils::StringUtils::to_string(json_str.size()); + req->SetContentType("application/json"); + req->AddContentBody(Aws::MakeShared("", json_str)); + req->SetContentLength(json_len); + const std::shared_ptr response = http_client->MakeRequest(req); + + // Check resp status + if (response->GetResponseCode() != Aws::Http::HttpResponseCode::OK) { + LOG(ERROR) << "OKTA request returned bad HTTP response code: " << response->GetResponseCode(); + if (response->HasClientError()) { + LOG(ERROR) << "HTTP Client Error: " << response->GetClientErrorMessage(); + } + } else { + const Aws::Utils::Json::JsonValue json_val(response->GetResponseBody()); + if (!json_val.WasParseSuccessful()) { + LOG(ERROR) << "Unable to parse JSON from response"; + continue; + } + + const Aws::Utils::Json::JsonView json_view = json_val.View(); + if (!json_view.KeyExists("sessionToken")) { + LOG(ERROR) << "Could not find session token in JSON"; + continue; + } + + return json_view.GetString("sessionToken"); + } + + std::this_thread::sleep_for(std::chrono::seconds(VERIFY_PUSH_INTERVAL)); + } + LOG(ERROR) << "The MFA challenge was not completed in time"; + return ""; +} diff --git a/driver/plugin/federated/okta_auth_plugin.h b/driver/plugin/federated/okta_auth_plugin.h index d0239a4..8f4d44c 100644 --- a/driver/plugin/federated/okta_auth_plugin.h +++ b/driver/plugin/federated/okta_auth_plugin.h @@ -21,6 +21,18 @@ #include "../base_plugin.h" #include "../../driver.h" +typedef enum { + NONE, + TOTP, + PUSH +} MfaType; + +static std::map const mfa_type_table = { + {"", MfaType::NONE}, + {VALUE_MFA_TOTP, MfaType::TOTP}, + {VALUE_MFA_PUSH, MfaType::PUSH} +}; + class OktaSamlUtil : public SamlUtil { public: OktaSamlUtil(const std::map &connection_attributes); @@ -30,8 +42,19 @@ class OktaSamlUtil : public SamlUtil { private: std::string GetSessionToken(); + std::string VerifyTOTPChallenge(const std::string &verify_url, const std::string &state_token); + + std::string VerifyPushChallenge(const std::string &verify_url, const std::string &state_token); + + static inline const std::string DEFAULT_MFA_TIMEOUT = "60"; + static inline const int VERIFY_PUSH_INTERVAL = 5; + static inline const std::string DEFAULT_PORT = "8080"; + static inline const std::string WEBSERVER_HOST = "http://127.0.0.1"; std::string sign_in_url; std::string session_token_url; + MfaType mfa_type; + std::string mfa_port; + std::string mfa_timeout; static inline const std::regex SAML_RESPONSE_PATTERN = std::regex(""); }; diff --git a/driver/util/attribute_validator.cpp b/driver/util/attribute_validator.cpp index 4b68941..275fa55 100644 --- a/driver/util/attribute_validator.cpp +++ b/driver/util/attribute_validator.cpp @@ -34,7 +34,9 @@ bool AttributeValidator::ShouldKeyBeUnsignedInt(const std::string& key) { KEY_FAILOVER_TIMEOUT, KEY_LIMITLESS_MONITOR_INTERVAL_MS, KEY_ROUTER_MAX_RETRIES, - KEY_LIMITLESS_MAX_RETRIES + KEY_LIMITLESS_MAX_RETRIES, + KEY_MFA_PORT, + KEY_MFA_TIMEOUT }; return INTEGER_KEYS.contains(key); } diff --git a/driver/util/connection_string_helper.h b/driver/util/connection_string_helper.h index f97676a..a607a0d 100644 --- a/driver/util/connection_string_helper.h +++ b/driver/util/connection_string_helper.h @@ -48,6 +48,9 @@ static std::unordered_set const aws_odbc_key_set = { KEY_HTTP_CONNECT_TIMEOUT, KEY_RELAY_PARTY_ID, KEY_APP_ID, + KEY_MFA_TYPE, + KEY_MFA_PORT, + KEY_MFA_TIMEOUT, KEY_DATABASE_DIALECT, KEY_HOST_SELECTOR_STRATEGY, KEY_ENABLE_FAILOVER, diff --git a/driver/util/connection_string_keys.h b/driver/util/connection_string_keys.h index 435d570..f716eb5 100644 --- a/driver/util/connection_string_keys.h +++ b/driver/util/connection_string_keys.h @@ -75,6 +75,11 @@ #define KEY_RELAY_PARTY_ID "RELAY_PARTY_ID" /* OKTA */ #define KEY_APP_ID "APP_ID" +#define KEY_MFA_TYPE "MFA_TYPE" +#define KEY_MFA_PORT "MFA_PORT" +#define KEY_MFA_TIMEOUT "MFA_TIMEOUT" +#define VALUE_MFA_TOTP "TOTP" +#define VALUE_MFA_PUSH "PUSH" /* Host Selectors */ #define KEY_HOST_SELECTOR_STRATEGY "HOST_SELECTOR_STRATEGY"